[
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: Build and test\n\non:\n  create:\n  workflow_dispatch:\n  push:\n    branches:\n      - master\n  pull_request:\n    branches:\n      - master\n\njobs:\n  build-and-test-cpu:\n    strategy:\n      matrix:\n        os: [ubuntu-latest, macos-latest, windows-latest]\n\n    runs-on: ${{ matrix.os }}\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install OpenMP\n        if: matrix.os != 'windows-latest'\n        run: |\n          if [ \"${{ runner.os }}\" == \"Linux\" ]; then\n            sudo apt-get update && sudo apt-get install -y libomp-dev\n          elif [ \"${{ runner.os }}\" == \"macOS\" ]; then\n            brew install libomp\n          fi\n\n      - name: Install dependencies\n        run: pip install -r requirements.txt\n\n      - name: Run preprocessing\n        run: python dev/data/tinyshakespeare.py\n\n      - name: Train model\n        run: python train_gpt2.py --device=cpu\n\n      - name: Download Win32 Make.exe\n        if: matrix.os == 'windows-latest'\n        run: |\n            $wc = New-Object System.Net.WebClient\n            $url = 'https://github.com/maweil/MakeForWindows/releases/download/v4.4.1/make-bin-win64.zip'\n            $output = './make-bin-win64.zip'\n            $wc.DownloadFile($url, $output)\n\n      - name: Unzip Win32 Makefile\n        if: matrix.os == 'windows-latest'\n        run: |\n          unzip make-bin-win64.zip\n\n      - name: Compile training and testing program\n        if: matrix.os != 'windows-latest'\n        run: make test_gpt2 train_gpt2\n\n      - name: Compile training and testing program for Windows\n        if: matrix.os == 'windows-latest'\n        shell: cmd\n        run: |\n          call \"C:\\\\Program Files\\\\Microsoft Visual Studio\\\\2022\\\\Enterprise\\\\VC\\\\Auxiliary\\\\Build\\\\vcvars64.bat\"\n          make-4.4.1\\dist\\make WIN_CI_BUILD=1 test_gpt2 train_gpt2\n\n      - name: Execute testing program (With OpenMP)\n        
if: matrix.os != 'windows-latest'\n        run: OMP_NUM_THREADS=8 ./test_gpt2\n\n      - name: Execute Windows testing program (With OpenMP)\n        if: matrix.os == 'windows-latest'\n        shell: cmd\n        run: |\n          copy test_gpt2 test_gpt2.exe\n          test_gpt2.exe\n\n      - name: Compile training and testing program without OpenMP\n        if: matrix.os != 'windows-latest'\n        run: NO_OMP=1 make test_gpt2 train_gpt2\n\n      - name: Execute testing program (No OpenMP)\n        if: matrix.os != 'windows-latest'\n        run: ./test_gpt2\n\n  build-cuda-windows:\n    runs-on: windows-latest\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v4\n\n    - name: Download Win32 Make.exe\n      run: |\n          $wc = New-Object System.Net.WebClient\n          $url = 'https://github.com/maweil/MakeForWindows/releases/download/v4.4.1/make-bin-win64.zip'\n          $output = './make-bin-win64.zip'\n          $wc.DownloadFile($url, $output)\n\n    - name: Unzip Win32 Makefile\n      run: |\n        unzip make-bin-win64.zip\n\n    - name: Install Cuda Toolkit 12.4 on Windows\n      run: |\n        mkdir -p \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\"\n        choco install unzip -y\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip\"\n        curl -O 
\"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip\"\n        curl -O \"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip\"\n        unzip '*.zip' -d \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\"\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\cuda_cudart-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\cuda_nvcc-windows-x86_64-12.4.131-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\cuda_nvrtc-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\libcublas-windows-x86_64-12.4.5.8-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\cuda_nvtx-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing 
Toolkit\\CUDA\\v12.4\\cuda_profiler_api-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\visual_studio_integration-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\cuda_nvprof-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n        xcopy \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\cuda_cccl-windows-x86_64-12.4.127-archive\\*\" \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" /E /I /H /Y\n\n    # Default installation path for CUDA Toolkit is C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\n    - name: Add Path\n      run: |\n        echo \"C:\\\\Program Files\\\\NVIDIA GPU Computing Toolkit\\\\CUDA\\\\v12.4\\\\bin\" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append\n        echo \"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\libnvvp\" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append\n        echo \"CUDA_PATH=C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8\n        echo \"CUDA_PATH_V12_4=C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8\n\n    - name: Build Cuda targets\n      shell: cmd\n      working-directory: ${{ github.workspace }}\n      run: |\n        call \"C:\\\\Program Files\\\\Microsoft Visual Studio\\\\2022\\\\Enterprise\\\\VC\\\\Auxiliary\\\\Build\\\\vcvars64.bat\"\n        make-4.4.1\\dist\\make -j WIN_CI_BUILD=1 train_gpt2fp32cu test_gpt2fp32cu test_gpt2cu train_gpt2cu profile_gpt2cu\n\n  build-ubuntu20-04:\n    runs-on: ubuntu-20.04\n    container:\n      image: 
nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: System Info\n        run: |\n          nvcc --version\n          g++ --version\n\n      - name: Install cudnn frontend\n        run: |\n          apt-get update && apt-get install -y git\n          git clone https://github.com/NVIDIA/cudnn-frontend.git\n\n      - name: Build FP32 checkpoint\n        run: make train_gpt2fp32cu test_gpt2fp32cu\n\n      - name: Build FP32 precision\n        run: PRECISION=FP32 make train_gpt2cu test_gpt2cu profile_gpt2cu\n\n      - name: Build with CUDNN\n        run: PRECISION=BF16 USE_CUDNN=1 make train_gpt2cu test_gpt2cu profile_gpt2cu\n\n  build-cuda-fp32:\n    runs-on: ubuntu-latest\n    container:\n      image: nvidia/cuda:12.4.1-devel-ubuntu22.04\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Build FP32 checkpoint\n        run: make train_gpt2fp32cu test_gpt2fp32cu\n\n      - name: Build FP32 precision\n        run: PRECISION=FP32 make train_gpt2cu test_gpt2cu profile_gpt2cu\n\n  build-cuda-bf16:\n    runs-on: ubuntu-latest\n    container:\n      image: nvidia/cuda:12.4.1-devel-ubuntu22.04\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Build project\n        run: PRECISION=BF16 make test_gpt2cu train_gpt2cu profile_gpt2cu\n\n  build-cuda-fp16:\n    runs-on: ubuntu-latest\n    container:\n      image: nvidia/cuda:12.4.1-devel-ubuntu22.04\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Build project\n        run: PRECISION=FP16 make test_gpt2cu train_gpt2cu profile_gpt2cu\n\n  build-cuda-kernels:\n    runs-on: ubuntu-latest\n    container:\n      image: nvidia/cuda:12.4.1-devel-ubuntu22.04\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install OpenMP and OpenMPI\n        run: apt-get update && apt-get 
install -y libomp-dev libopenmpi-dev\n\n      - name: Build project\n        run: make -j4 -C dev/cuda\n"
  },
  {
    "path": ".github/workflows/ci_gpu.yml",
    "content": "name: GPU Builds and Tests\n\non:\n  create:\n  workflow_dispatch:\n  push:\n    branches:\n      - master\n  pull_request:\n    branches:\n      - master\n\njobs:\n  build-and-test-gpu:\n    runs-on: ubicloud-gpu-standard-1-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install OpenMP\n        run: sudo apt-get update && sudo apt-get install -y libomp-dev\n\n      - name: Install dependencies\n        run: pip install -r requirements.txt\n\n      - name: Run preprocessing\n        run: python dev/data/tinyshakespeare.py\n\n      - name: Train model\n        run: python train_gpt2.py\n\n      - name: Compile training and testing program\n        run: make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu\n\n      - name: Train model (With OpenMP)\n        run: OMP_NUM_THREADS=8 ./train_gpt2cu\n\n      - name: Train model (FP32) with gpt2_124M.bin\n        run: |\n          PRECISION=FP32 make train_gpt2cu\n          ./train_gpt2cu -b 1 -t 64 -d 256 -l 0.0001 -v 200 -s 200 -a 1 -x 10 -r 0 -f 0 -e \"gpt2_124M.bin\"\n\n      - name: Test for percent loss differential for FP32 \n        run: |\n          PRECISION=FP32 make train_gpt2cu\n          ./train_gpt2cu -b 1 -t 64 -d 256 -l 0.0001 -v 200 -s 200 -a 1 -x 10 -r 0 -f 0 -e \"gpt2_124M.bin\" > train_gpt2cu_fp32_precision.txt\n          python dev/loss_checker_ci.py -f train_gpt2cu_fp32_precision.txt -s 20 -e 28 -a 5.0\n\n      - name: Build FP32 precision\n        run: PRECISION=FP32 make test_gpt2cu profile_gpt2cu\n\n      - name: Run default\n        run: ./test_gpt2cu\n\n      - name: Run no recompute GeLU\n        run: ./test_gpt2cu -r 0\n\n      - name: Run recompute LN\n        run: ./test_gpt2cu -r 2\n\n      - name: Build BF16 precision\n        run: PRECISION=BF16 make train_gpt2cu test_gpt2cu profile_gpt2cu\n\n      - name: Run default\n        run: ./test_gpt2cu\n\n      - name: Run no recompute GeLU\n        run: ./test_gpt2cu 
-r 0\n\n      - name: Run no master weights\n        run: ./test_gpt2cu -w 0\n\n      - name: Run recompute LN\n        run: ./test_gpt2cu -r 2\n\n      - name: Train model fp32 (With OpenMP)\n        run: OMP_NUM_THREADS=8 ./train_gpt2fp32cu\n\n      - name: Execute testing program (With OpenMP)\n        run: OMP_NUM_THREADS=8 ./test_gpt2cu\n\n      - name: Execute testing program fp32 (With OpenMP)\n        run: OMP_NUM_THREADS=8 ./test_gpt2fp32cu\n\n      - name: Compile training and testing program without OpenMP\n        run: NO_OMP=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu\n\n      - name: Train model (No OpenMP)\n        run: NO_OMP=1 ./train_gpt2cu\n\n      - name: Train model fp32 (No OpenMP)\n        run: NO_OMP=1 ./train_gpt2fp32cu\n\n      - name: Execute testing program (No OpenMP)\n        run: ./test_gpt2cu -b 32\n\n      - name: Execute testing program fp32 (No OpenMP)\n        run: ./test_gpt2fp32cu\n\n      - name: Install cuDNN-frontend\n        run:\n          git clone https://github.com/NVIDIA/cudnn-frontend.git\n\n      - name: Build with cuDNN\n        run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu\n\n      - name: Train model with cuDNN\n        run: ./train_gpt2cu\n\n      - name: Train model fp32 with cuDNN\n        run: ./train_gpt2fp32cu\n\n      - name: Execute testing program with cuDNN\n        run: ./test_gpt2cu\n\n      - name: Execute testing program fp32 with cuDNN\n        run: ./test_gpt2fp32cu\n\n  unit-tests-gpu:\n    runs-on: ubicloud-gpu-standard-1-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Test Device<->File IO\n        run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io\n"
  },
  {
    "path": ".github/workflows/ci_tests.yml",
    "content": "name: Unit, Static and other Tests\n\non:\n  create:\n  workflow_dispatch:\n  push:\n    branches:\n      - master\n  pull_request:\n    branches:\n      - master\n\njobs:\n  dataloader_test:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: test the dataloader without / with sanitize address\n        run: |\n          cd dev/test\n          make PRECISION=BF16 test_dataloader\n          ./test_dataloader   \n          make clean       \n          make PRECISION=BF16 TEST_CFLAGS=\"-fsanitize=address -fno-omit-frame-pointer\" test_dataloader \n          ./test_dataloader          \n\n  ptx_and_sass_files:\n    runs-on: ubuntu-latest\n    container:\n      image: nvidia/cuda:12.4.1-devel-ubuntu22.04\n\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Install OpenMP and OpenMPI\n        run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev\n    \n      - name: Generate ptx/sass files and upload them to persistent storage\n        run: |\n          mkdir -p dev/cuda/ptx_sass_logs\n          make train_gpt2cu\n          cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx\n          cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass          \n          cd dev/cuda\n          make -j all_ptx\n          make -j all_sass\n          cp *.ptx ptx_sass_logs/\n          cp *.sass ptx_sass_logs/\n          ls ptx_sass_logs/\n\n      - name: Generate ptx/sass files for A100 and upload them to persistent storage\n        run: |\n            mkdir -p dev/cuda/ptx_sass_logs_A100\n            make train_gpt2cu GPU_COMPUTE_CAPABILITY=80\n            cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx\n            cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass          \n            cd dev/cuda\n            make -j GPU_COMPUTE_CAPABILITY=80 all_ptx\n            make -j 
GPU_COMPUTE_CAPABILITY=80 all_sass\n            cp *.ptx ptx_sass_logs_A100/\n            cp *.sass ptx_sass_logs_A100/\n            ls ptx_sass_logs_A100/\n\n      - name: Generate ptx/sass files for H100 and upload them to persistent storage\n        run: |\n            mkdir -p dev/cuda/ptx_sass_logs_H100\n            make train_gpt2cu GPU_COMPUTE_CAPABILITY=90\n            cuobjdump --dump-ptx train_gpt2cu > dev/cuda/train_gpt2cu.ptx\n            cuobjdump --dump-sass train_gpt2cu > dev/cuda/train_gpt2cu.sass          \n            cd dev/cuda\n            make -j GPU_COMPUTE_CAPABILITY=90 all_ptx\n            make -j GPU_COMPUTE_CAPABILITY=90 all_sass\n            cp *.ptx ptx_sass_logs_H100/\n            cp *.sass ptx_sass_logs_H100/\n            ls ptx_sass_logs_H100/\n\n      - name: Upload ptx/sass files\n        uses: actions/upload-artifact@v4\n        with:\n          name: ptx_sass_files\n          path: dev/cuda/ptx_sass_logs/\n          retention-days: 30 # days to retain\n\n      - name: Upload ptx/sass files for A100\n        uses: actions/upload-artifact@v4\n        with:\n          name: ptx_sass_files_A100\n          path: dev/cuda/ptx_sass_logs_A100/\n          retention-days: 30 # days to retain          \n\n      - name: Upload ptx/sass files for H100\n        uses: actions/upload-artifact@v4\n        with:\n          name: ptx_sass_files_H100\n          path: dev/cuda/ptx_sass_logs_H100/\n          retention-days: 30 # days to retain                    "
  },
  {
    "path": ".gitignore",
    "content": "# dot files and such\n.vscode\n.venv\n\n# .bin files generated by Python\n*.bin\n\n# data directories\ndev/data/__pycache__/\ndev/data/fineweb10B/\ndev/data/hellaswag/\ndev/data/mmlu/\ndev/data/tinyshakespeare/\ndev/data/tinystories/\n\n# binaries\ntest_gpt2\ntest_gpt2cu\ntest_gpt2fp32cu\ntrain_gpt2\ntrain_gpt2cu\ntrain_gpt2fp32cu\nprofile_gpt2cu\ndev/cuda/*_forward\ndev/cuda/*_backward\ndev/cuda/classifier_fused\ndev/cuda/adamw\ndev/cuda/matmul_backward_bias\ndev/cuda/nccl_all_reduce\ndev/cuda/global_norm\n*.obj\n*.exe\n*.o\n\n# log files\n*.log\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Andrej Karpathy\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Makefile",
    "content": "CC ?= clang\nCFLAGS = -Ofast -Wno-unused-result -Wno-ignored-pragmas -Wno-unknown-attributes\nLDFLAGS =\nLDLIBS = -lm\nINCLUDES =\nCFLAGS_COND = -march=native\n\n# Find nvcc\nSHELL_UNAME = $(shell uname)\nREMOVE_FILES = rm -f\nOUTPUT_FILE = -o $@\nCUDA_OUTPUT_FILE = -o $@\n\n# Default O3 CPU optimization level for NVCC (0 for fastest compile time)\nFORCE_NVCC_O ?= 3\n\n# NVCC flags\n# -t=0 is short for --threads, 0 = number of CPUs on the machine\nNVCC_FLAGS = --threads=0 -t=0 --use_fast_math -std=c++17 -O$(FORCE_NVCC_O)\nNVCC_LDFLAGS = -lcublas -lcublasLt\nNVCC_INCLUDES =\nNVCC_LDLIBS =\nNCLL_INCUDES =\nNVCC_CUDNN =\n# By default we don't build with cudnn because it blows up compile time from a few seconds to ~minute\nUSE_CUDNN ?= 0\n\n# We will place .o files in the `build` directory (create it if it doesn't exist)\nBUILD_DIR = build\nifeq ($(OS), Windows_NT)\n  $(shell if not exist $(BUILD_DIR) mkdir $(BUILD_DIR))\n  REMOVE_BUILD_OBJECT_FILES := del $(BUILD_DIR)\\*.obj\nelse\n  $(shell mkdir -p $(BUILD_DIR))\n  REMOVE_BUILD_OBJECT_FILES := rm -f $(BUILD_DIR)/*.o\nendif\n\n# Function to check if a file exists in the PATH\nifneq ($(OS), Windows_NT)\ndefine file_exists_in_path\n  $(which $(1) 2>/dev/null)\nendef\nelse\ndefine file_exists_in_path\n  $(shell where $(1) 2>nul)\nendef\nendif\n\nifneq ($(CI),true) # if not in CI, then use the GPU query\n  ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=\n    ifneq ($(call file_exists_in_path, nvidia-smi),)\n      # Get the compute capabilities of all GPUs\n      # Remove decimal points, sort numerically in ascending order, and select the first (lowest) value\n      GPU_COMPUTE_CAPABILITY=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | sed 's/\\.//g' | sort -n | head -n 1)\n      GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))\n    endif\n  endif\nendif\n\n# set to defaults if - make GPU_COMPUTE_CAPABILITY= otherwise use the compute capability 
detected above\nifneq ($(GPU_COMPUTE_CAPABILITY),)\n  NVCC_FLAGS += --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)]\nendif\n\n# autodect a lot of various supports on current platform\n$(info ---------------------------------------------)\n\nifneq ($(OS), Windows_NT)\n  NVCC := $(shell which nvcc 2>/dev/null)\n  NVCC_LDFLAGS += -lnvidia-ml\n\n  # Function to test if the compiler accepts a given flag.\n  define check_and_add_flag\n    $(eval FLAG_SUPPORTED := $(shell printf \"int main() { return 0; }\\n\" | $(CC) $(1) -x c - -o /dev/null 2>/dev/null && echo 'yes'))\n    ifeq ($(FLAG_SUPPORTED),yes)\n        CFLAGS += $(1)\n    endif\n  endef\n\n  # Check each flag and add it if supported\n  $(foreach flag,$(CFLAGS_COND),$(eval $(call check_and_add_flag,$(flag))))\nelse\n  CFLAGS :=\n  REMOVE_FILES = del *.exe,*.obj,*.lib,*.exp,*.pdb && del\n  SHELL_UNAME := Windows\n  ifneq ($(shell where nvcc 2> nul),\"\")\n    NVCC := nvcc\n  else\n    NVCC :=\n  endif\n  CC := cl\n  CFLAGS = /Idev /Zi /nologo /W4 /WX- /diagnostics:column /sdl /O2 /Oi /Ot /GL /D _DEBUG /D _CONSOLE /D _UNICODE /D UNICODE /Gm- /EHsc /MD /GS /Gy /fp:fast /Zc:wchar_t /Zc:forScope /Zc:inline /permissive- \\\n   /external:W3 /Gd /TP /wd4996 /Fd$@.pdb /FC /openmp:llvm\n  LDFLAGS :=\n  LDLIBS :=\n  INCLUDES :=\n  NVCC_FLAGS += -I\"dev\"\n  ifeq ($(WIN_CI_BUILD),1)\n    $(info Windows CI build)\n    OUTPUT_FILE = /link /OUT:$@\n    CUDA_OUTPUT_FILE = -o $@\n  else\n    $(info Windows local build)\n    OUTPUT_FILE = /link /OUT:$@ && copy /Y $@ $@.exe\n    CUDA_OUTPUT_FILE = -o $@ && copy /Y $@.exe $@\n  endif\nendif\n\n# Check and include cudnn if available\n# You can override the path to cudnn frontend by setting CUDNN_FRONTEND_PATH on the make command line\n# By default, we look for it in HOME/cudnn-frontend/include and ./cudnn-frontend/include\n# Refer to the README for cuDNN install instructions\nifeq ($(USE_CUDNN), 1)\n  ifeq 
($(SHELL_UNAME), Linux)\n    ifeq ($(shell [ -d $(HOME)/cudnn-frontend/include ] && echo \"exists\"), exists)\n      $(info ✓ cuDNN found, will run with flash-attention)\n      CUDNN_FRONTEND_PATH ?= $(HOME)/cudnn-frontend/include\n    else ifeq ($(shell [ -d cudnn-frontend/include ] && echo \"exists\"), exists)\n      $(info ✓ cuDNN found, will run with flash-attention)\n      CUDNN_FRONTEND_PATH ?= cudnn-frontend/include\n    else\n      $(error ✗ cuDNN not found. See the README for install instructions and the Makefile for hard-coded paths)\n    endif\n    NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH)\n    NVCC_LDFLAGS += -lcudnn\n    NVCC_FLAGS += -DENABLE_CUDNN\n    NVCC_CUDNN = $(BUILD_DIR)/cudnn_att.o\n  else\n    ifneq ($(OS), Windows_NT)\n      $(info → cuDNN is not supported on MAC OS right now)\n    else\n      $(info ✓ Windows cuDNN found, will run with flash-attention)\n      ifeq ($(shell if exist \"$(HOMEDRIVE)$(HOMEPATH)\\cudnn-frontend\\include\" (echo exists)),exists)\n        CUDNN_FRONTEND_PATH ?= $(HOMEDRIVE)$(HOMEPATH)\\cudnn-frontend\\include #override on command line if different location\n      else ifeq ($(shell if exist \"cudnn-frontend\\include\" (echo exists)),exists)\n        CUDNN_FRONTEND_PATH ?= cudnn-frontend\\include #override on command line if different location\n      else\n        $(error ✗ cuDNN not found. 
See the README for install instructions and the Makefile for hard-coded paths)\n      endif\n      CUDNN_INCLUDE_PATH ?= -I\"C:\\Program Files\\NVIDIA\\CUDNN\\v9.1\\include\\12.4\"\n      CUDNN_FRONTEND_PATH += $(CUDNN_INCLUDE_PATH)\n      NVCC_FLAGS += --std c++20 -Xcompiler \"/std:c++20\" -Xcompiler \"/EHsc /W0 /nologo /Ox /FS\" -maxrregcount=0 --machine 64\n      NVCC_CUDNN = $(BUILD_DIR)\\cudnn_att.obj\n      NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH)\n      NVCC_LDFLAGS += -L\"C:\\Program Files\\NVIDIA\\CUDNN\\v9.1\\lib\\12.4\\x64\" -lcudnn\n      NVCC_FLAGS += -DENABLE_CUDNN\n    endif\n  endif\nelse\n  $(info → cuDNN is manually disabled by default, run make with `USE_CUDNN=1` to try to enable)\nendif\n\n# Check if OpenMP is available\n# This is done by attempting to compile an empty file with OpenMP flags\n# OpenMP makes the code a lot faster so I advise installing it\n# e.g. on MacOS: brew install libomp\n# e.g. on Ubuntu: sudo apt-get install libomp-dev\n# later, run the program by prepending the number of threads, e.g.: OMP_NUM_THREADS=8 ./gpt2\n# First, check if NO_OMP is set to 1, if not, proceed with the OpenMP checks\nifeq ($(NO_OMP), 1)\n  $(info OpenMP is manually disabled)\nelse\n  ifneq ($(OS), Windows_NT)\n  # Detect if running on macOS or Linux\n    ifeq ($(SHELL_UNAME), Darwin)\n      # Check for Homebrew's libomp installation in different common directories\n      ifeq ($(shell [ -d /opt/homebrew/opt/libomp/lib ] && echo \"exists\"), exists)\n        # macOS with Homebrew on ARM (Apple Silicon)\n        CFLAGS += -Xclang -fopenmp -DOMP\n        LDFLAGS += -L/opt/homebrew/opt/libomp/lib\n        LDLIBS += -lomp\n        INCLUDES += -I/opt/homebrew/opt/libomp/include\n        $(info ✓ OpenMP found)\n      else ifeq ($(shell [ -d /usr/local/opt/libomp/lib ] && echo \"exists\"), exists)\n        # macOS with Homebrew on Intel\n        CFLAGS += -Xclang -fopenmp -DOMP\n        LDFLAGS += -L/usr/local/opt/libomp/lib\n        LDLIBS += -lomp\n        
INCLUDES += -I/usr/local/opt/libomp/include\n        $(info ✓ OpenMP found)\n      else\n        $(info ✗ OpenMP not found)\n      endif\n    else\n      # Check for OpenMP support in GCC or Clang on Linux\n      ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0)\n        CFLAGS += -fopenmp -DOMP\n        LDLIBS += -lgomp\n        $(info ✓ OpenMP found)\n      else\n        $(info ✗ OpenMP not found)\n      endif\n    endif\n  endif\nendif\n\n# Check if NCCL is available, include if so, for multi-GPU training\nifeq ($(NO_MULTI_GPU), 1)\n  $(info → Multi-GPU (NCCL) is manually disabled)\nelse\n  ifneq ($(OS), Windows_NT)\n    # Detect if running on macOS or Linux\n    ifeq ($(SHELL_UNAME), Darwin)\n      $(info ✗ Multi-GPU on CUDA on Darwin is not supported, skipping NCCL support)\n    else ifeq ($(shell dpkg -l | grep -q nccl && echo \"exists\"), exists)\n      $(info ✓ NCCL found, OK to train with multiple GPUs)\n      NVCC_FLAGS += -DMULTI_GPU\n      NVCC_LDLIBS += -lnccl\n    else\n      $(info ✗ NCCL is not found, disabling multi-GPU support)\n      $(info ---> On Linux you can try install NCCL with `sudo apt install libnccl2 libnccl-dev`)\n    endif\n  endif\nendif\n\n# Attempt to find and include OpenMPI on the system\nOPENMPI_DIR ?= /usr/lib/x86_64-linux-gnu/openmpi\nOPENMPI_LIB_PATH = $(OPENMPI_DIR)/lib/\nOPENMPI_INCLUDE_PATH = $(OPENMPI_DIR)/include/\nifeq ($(NO_USE_MPI), 1)\n  $(info → MPI is manually disabled)\nelse ifeq ($(shell [ -d $(OPENMPI_LIB_PATH) ] && [ -d $(OPENMPI_INCLUDE_PATH) ] && echo \"exists\"), exists)\n  $(info ✓ MPI enabled)\n  NVCC_INCLUDES += -I$(OPENMPI_INCLUDE_PATH)\n  NVCC_LDFLAGS += -L$(OPENMPI_LIB_PATH)\n  NVCC_LDLIBS += -lmpi\n  NVCC_FLAGS += -DUSE_MPI\nelse\n  $(info ✗ MPI not found)\nendif\n\n# Precision settings, default to bf16 but ability to override\nPRECISION ?= BF16\nVALID_PRECISIONS := FP32 FP16 BF16\nifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),)\n  $(error Invalid precision $(PRECISION), 
valid precisions are $(VALID_PRECISIONS))\nendif\nifeq ($(PRECISION), FP32)\n  PFLAGS = -DENABLE_FP32\nelse ifeq ($(PRECISION), FP16)\n  PFLAGS = -DENABLE_FP16\nelse\n  PFLAGS = -DENABLE_BF16\nendif\n\n# PHONY means these targets will always be executed\n.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu\n\n# Add targets\nTARGETS = train_gpt2 test_gpt2\n\n# Conditional inclusion of CUDA targets\nifeq ($(NVCC),)\n    $(info ✗ nvcc not found, skipping GPU/CUDA builds)\nelse\n    $(info ✓ nvcc found, including GPU/CUDA support)\n    TARGETS += train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu $(NVCC_CUDNN)\nendif\n\n$(info ---------------------------------------------)\n\nall: $(TARGETS)\n\ntrain_gpt2: train_gpt2.c\n\t$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $^ $(LDLIBS) $(OUTPUT_FILE)\n\ntest_gpt2: test_gpt2.c\n\t$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $^ $(LDLIBS) $(OUTPUT_FILE)\n\n$(NVCC_CUDNN): llmc/cudnn_att.cpp\n\t$(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_INCLUDES) -o $@\n\ntrain_gpt2cu: train_gpt2.cu $(NVCC_CUDNN)\n\t$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)\n\ntrain_gpt2fp32cu: train_gpt2_fp32.cu\n\t$(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)\n\ntest_gpt2cu: test_gpt2.cu $(NVCC_CUDNN)\n\t$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)\n\ntest_gpt2fp32cu: test_gpt2_fp32.cu\n\t$(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)\n\nprofile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN)\n\t$(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS)  $(CUDA_OUTPUT_FILE)\n\nclean:\n\t$(REMOVE_FILES) $(TARGETS)\n\t$(REMOVE_BUILD_OBJECT_FILES)\n"
  },
  {
    "path": "README.md",
    "content": "# llm.c\n\nLLMs in simple, pure C/CUDA with no need for 245MB of PyTorch or 107MB of cPython. Current focus is on pretraining, in particular reproducing the [GPT-2](https://github.com/openai/gpt-2) and [GPT-3](https://arxiv.org/abs/2005.14165) miniseries, along with a parallel PyTorch reference implementation in [train_gpt2.py](train_gpt2.py). You'll recognize this file as a slightly tweaked [nanoGPT](https://github.com/karpathy/nanoGPT), an earlier project of mine. Currently, llm.c is a bit faster than PyTorch Nightly (by about 7%). In addition to the bleeding edge mainline code in [train_gpt2.cu](train_gpt2.cu), we have a simple reference CPU fp32 implementation in ~1,000 lines of clean code in one file [train_gpt2.c](train_gpt2.c). I'd like this repo to only maintain C and CUDA code. Ports to other languages or repos are very welcome, but should be done in separate repos, and I am happy to link to them below in the \"notable forks\" section. Developer coordination happens in the [Discussions](https://github.com/karpathy/llm.c/discussions) and on Discord, either the `#llmc` channel on the [Zero to Hero](https://discord.gg/3zy8kqD9Cp) channel, or on `#llmdotc` on [GPU MODE](https://discord.gg/gpumode) Discord.\n\n## quick start\n\nThe best introduction to the llm.c repo today is reproducing the GPT-2 (124M) model. [Discussion #481](https://github.com/karpathy/llm.c/discussions/481) steps through this in detail. We can reproduce other models from the GPT-2 and GPT-3 series in both llm.c and in the parallel implementation of PyTorch. Have a look at the [scripts README](scripts/README.md).\n\ndebugging tip: when you run the `make` command to build the binary, modify it by replacing `-O3` with `-g` so you can step through the code in your favorite IDE (e.g. 
vscode).\n\n## quick start (1 GPU, fp32 only)\n\nIf you won't be training on multiple nodes, aren't interested in mixed precision, and are interested in learning CUDA, the fp32 (legacy) files might be of interest to you. These are files that were \"checkpointed\" early in the history of llm.c and frozen in time. They are simpler, more portable, and possibly easier to understand. Run the 1 GPU, fp32 code like this:\n\n```bash\nchmod u+x ./dev/download_starter_pack.sh\n./dev/download_starter_pack.sh\nmake train_gpt2fp32cu\n./train_gpt2fp32cu\n```\n\nThe download_starter_pack.sh script is a quick & easy way to get started and it downloads a bunch of .bin files that help get you off the ground. These contain: 1) the GPT-2 124M model saved in fp32, in bfloat16, 2) a \"debug state\" used in unit testing (a small batch of data, and target activations and gradients), 3) the GPT-2 tokenizer, and 4) the tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. Alternatively, instead of running the .sh script, you can re-create these artifacts manually as follows:\n\n```bash\npip install -r requirements.txt\npython dev/data/tinyshakespeare.py\npython train_gpt2.py\n```\n\n## quick start (CPU)\n\nThe \"I am so GPU poor that I don't even have one GPU\" section. You can still enjoy seeing llm.c train! But you won't go too far. Just like the fp32 version above, the CPU version is an even earlier checkpoint in the history of llm.c, back when it was just a simple reference implementation in C. 
For example, instead of training from scratch, you can finetune a GPT-2 small (124M) to output Shakespeare-like text:\n\n```bash\nchmod u+x ./dev/download_starter_pack.sh\n./dev/download_starter_pack.sh\nmake train_gpt2\nOMP_NUM_THREADS=8 ./train_gpt2\n```\n\nIf you'd prefer to avoid running the starter pack script, then as mentioned in the previous section you can reproduce the exact same .bin files and artifacts by running `python dev/data/tinyshakespeare.py` and then `python train_gpt2.py`.\n\nThe above lines (1) download an already tokenized [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, (2) download the GPT-2 (124M) weights, (3) init from them in C and train for 40 steps on tinyshakespeare with AdamW (using batch size 4, context length only 64), evaluate validation loss, and sample some text. Honestly, unless you have a beefy CPU (and can crank up the number of OMP threads in the launch command), you're not going to get that far on CPU training LLMs, but it might be a good demo/reference. The output looks like this on my MacBook Pro (Apple Silicon M3 Max):\n\n```\n[GPT-2]\nmax_seq_len: 1024\nvocab_size: 50257\nnum_layers: 12\nnum_heads: 12\nchannels: 768\nnum_parameters: 124439808\ntrain dataset num_batches: 1192\nval dataset num_batches: 128\nnum_activations: 73323776\nval loss 5.252026\nstep 0: train loss 5.356189 (took 1452.121000 ms)\nstep 1: train loss 4.301069 (took 1288.673000 ms)\nstep 2: train loss 4.623322 (took 1369.394000 ms)\nstep 3: train loss 4.600470 (took 1290.761000 ms)\n... 
(truncated) ...\nstep 39: train loss 3.970751 (took 1323.779000 ms)\nval loss 4.107781\ngenerating:\n---\nCome Running Away,\nGreater conquer\nWith the Imperial blood\nthe heaviest host of the gods\ninto this wondrous world beyond.\nI will not back thee, for how sweet after birth\nNetflix against repounder,\nwill not\nflourish against the earlocks of\nAllay\n---\n```\n\n## datasets\n\nThe data files inside `/dev/data/(dataset).py` are responsible for downloading, tokenizing and saving the tokens to .bin files, readable easily from C. So for example when you run:\n\n```bash\npython dev/data/tinyshakespeare.py\n```\n\nWe download and tokenize the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset. The output of this looks like this:\n\n```\nwriting 32,768 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_val.bin\nwriting 305,260 tokens to ./dev/data/tinyshakespeare/tiny_shakespeare_train.bin\n```\n\nThe .bin files contain a short header (1024 bytes) and then a stream of tokens in uint16, indicating the token ids with the GPT-2 tokenizer. More datasets are available in `/dev/data`.\n\n## test\n\nI am also attaching a simple unit test for making sure our C code agrees with the PyTorch code. On the CPU as an example, compile and run with:\n\n```bash\nmake test_gpt2\n./test_gpt2\n```\n\nThis now loads the `gpt2_124M_debug_state.bin` file that gets written by train_gpt2.py, runs a forward pass, compares the logits and loss with the PyTorch reference implementation, then it does 10 iterations of training with Adam and makes sure the losses match PyTorch. To test the GPU version we run:\n\n```bash\n# fp32 test (cudnn not supported)\nmake test_gpt2cu PRECISION=FP32 && ./test_gpt2cu\n# mixed precision cudnn test\nmake test_gpt2cu USE_CUDNN=1 && ./test_gpt2cu\n```\n\nThis tests both the fp32 path and the mixed precision path. 
The test should pass and print `overall okay: 1`.\n\n## tutorial\n\nI attached a very small tutorial here, in [doc/layernorm/layernorm.md](doc/layernorm/layernorm.md). It's a simple, step-by-step guide to implementing a single layer of the GPT-2 model, the layernorm layer. This is a good starting point to understand how the layers are implemented in C.\n\n**flash attention**. As of May 1, 2024 we use the Flash Attention from cuDNN. Because cuDNN bloats the compile time from a few seconds to ~minute and this code path is right now very new, this is disabled by default. You can enable it by compiling like this:\n\n```bash\nmake train_gpt2cu USE_CUDNN=1\n```\n\nThis will try to compile with cudnn and run it. You have to have cuDNN installed on your system. The [cuDNN installation instructions](https://developer.nvidia.com/cudnn) with apt-get will grab the default set of cuDNN packages. For a minimal setup, the cuDNN dev package is sufficient, e.g. on Ubuntu 22.04 for CUDA 12.x:\n\n```bash\nwget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb\nsudo dpkg -i cuda-keyring_1.1-1_all.deb\nsudo apt-get update\nsudo apt-get -y install libcudnn9-dev-cuda-12\n```\n\nOn top of this you need the [cuDNN frontend](https://github.com/NVIDIA/cudnn-frontend/tree/main), but this is just header files. Simply clone the repo to your disk. The Makefile currently looks for it in either your home directory or the current directory. If you have put it elsewhere, add `CUDNN_FRONTEND_PATH=/path/to/your/cudnn-frontend/include` to the `make` command-line.\n\n## multi-GPU training\n\nMake sure you install MPI and NCCL, e.g. on Linux:\n\n```bash\nsudo apt install openmpi-bin openmpi-doc libopenmpi-dev\n```\n\nFor NCCL follow the instructions from the [official website](https://developer.nvidia.com/nccl/nccl-download) (e.g. 
network installer)\n\nand then:\n\n```bash\nmake train_gpt2cu\nmpirun -np <number of GPUs> ./train_gpt2cu\n```\n\nor simply run one of our scripts under `./scripts/`.\n\n## multi-node training\n\nMake sure you've installed `NCCL` following instructions from [multi-GPU](#multi-gpu-training) section.\n\nThere are 3 ways we currently support that allow you to run multi-node training:\n1) Use OpenMPI to exchange nccl id and initialize NCCL. See e.g. `./scripts/multi_node/run_gpt2_124M_mpi.sh` script for details.\n2) Use shared file system to init NCCL. See `./scripts/multi_node/run_gpt2_124M_fs.sbatch` script for details.\n3) Use TCP sockets to init NCCL. See `./scripts/multi_node/run_gpt2_124M_tcp.sbatch` script for details.\n\nNote:\n* If you're running in a slurm environment and your slurm doesn't support PMIx (which we assume will be a common situation given that `slurm-wlm` dropped PMIx support) you will have to use FS (2) or TCP (3) approach. To test whether your slurm supports PMIx run: `srun --mpi=list` and see whether you get `pmix` in the output.\n* If you don't have slurm set up, you can kick off a multi-node run using `mpirun` - MPI (1).\n\nNone of these 3 methods is superior, we just offer you options so that you can run in your specific environment.\n\n## experiments / sweeps\n\nJust as an example process to sweep learning rates on a machine with 4 GPUs on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`):\n\n```bash\n#!/bin/bash\n\nlearning_rates=(3e-5 1e-4 3e-4 1e-3)\n\nfor i in {0..3}; do\n    export CUDA_VISIBLE_DEVICES=$i\n    screen -dmS \"tr$i\" bash -c \"./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -l ${learning_rates[$i]} -o stories$i.log\"\ndone\n\n# you can bring these down with\n# screen -ls | grep -E \"tr[0-3]\" | cut -d. -f1 | xargs -I {} screen -X -S {} quit\n```\n\nThis example opens up 4 screen sessions and runs the four commands with different LRs. 
This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. A quick example of how to parse and plot these logfiles is in [dev/vislog.ipynb](dev/vislog.ipynb).\n\n## repo\n\nA few more words on what I want this repo to be:\n\nFirst, I want `llm.c` to be a place for education. E.g. our `dev/cuda` folder is a place for a library of kernels for all the layers that are manually hand-written and very well documented, starting from very simple kernels all the way to more complex / faster kernels. If you have a new kernel with various different tradeoffs, please feel free to contribute it here.\n\nThat said, I also want `llm.c` to be very fast too, even practically useful to train networks. E.g. to start, we should be able to reproduce the big GPT-2 (1.6B) training run. This requires that we incorporate whatever fastest kernels there are, including the use of libraries such as cuBLAS, cuBLASLt, CUTLASS, cuDNN, etc. I also think doing so serves an educational purpose to establish an expert upper bound, and a unit of measurement, e.g. you could say that your manually written kernels are 80% of cuBLAS speed, etc. Then you can choose to do a super fast run, or you can choose to \"drag and drop\" whatever manual kernels you wish to use, and run with those.\n\nHowever, as a constraint, I want to keep the mainline `llm.c` in the root folder simple and readable. If there is a PR that e.g. improves performance by 2% but it \"costs\" 500 lines of complex C code, and maybe an exotic 3rd party dependency, I may reject the PR because the complexity is not worth it. As a concrete example - making cuBLAS for matmuls the default in the root training loop is a no-brainer: it makes the mainline code much faster, it is a single line of interpretable code, and it is a very common dependency. 
On the side of this, we can have manual implementations that can compete with cuBLAS in `dev/cuda`.\n\nLastly, I will be a lot more sensitive to complexity in the root folder of the project, which contains the main / default files of the project. In comparison, the `dev/` folder is a bit more of a scratch space for us to develop a library of kernels or classes and share useful or related or educational code, and some of this code could be ok to be (locally) complex.\n\n## notable forks\n\n- AMD support\n  - [llm.c](https://github.com/anthonix/llm.c) by @[anthonix](https://github.com/anthonix): support for AMD devices, such as the 7900 XTX\n\n- C#\n  - [llm.cs](https://github.com/azret/llm.cs) by @[azret](https://github.com/azret): a C# port of this project\n  - [Llm.cs](https://github.com/nietras/Llm.cs) by @[nietras](https://github.com/nietras): a C# port of this project with focus on easy to get started on any platform. Clone and run ✅\n\n- CUDA C++\n  - [llm.cpp](https://github.com/gevtushenko/llm.c) by @[gevtushenko](https://github.com/gevtushenko): a port of this project using the [CUDA C++ Core Libraries](https://github.com/NVIDIA/cccl)\n     - A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [GPU MODE Discord Server](https://discord.gg/cudamode)\n\n- C++/CUDA\n  - [llm.cpp](https://github.com/zhangpiu/llm.cpp/tree/master/llmcpp) by @[zhangpiu](https://github.com/zhangpiu): a port of this project using the [Eigen](https://gitlab.com/libeigen/eigen), supporting CPU/CUDA.\n\n- WebGPU C++\n  - [gpu.cpp](https://github.com/AnswerDotAI/gpu.cpp) by @[austinvhuang](https://github.com/austinvhuang): a library for portable GPU compute in C++ using native WebGPU. 
Aims to be a general-purpose library, but also porting llm.c kernels to WGSL.\n  \n- C++\n  - [llm.cpp](https://github.com/GaoYusong/llm.cpp) by @[GaoYusong](https://github.com/GaoYusong): a port of this project featuring a C++ single-header [tinytorch.hpp](https://github.com/GaoYusong/llm.cpp/blob/main/tinytorch.hpp) library\n\n- Go\n  - [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project\n\n- Java\n  - [llm.java](https://github.com/harryjackson/llm.java) by @[harryjackson](https://github.com/harryjackson): a Java port of this project\n\n- Metal\n  - [llm.metal](https://github.com/regrettable-username/llm.metal) by @[regrettable-username](https://github.com/regrettable-username): LLM training in simple, raw C/Metal Shading Language\n\n- Mojo\n  - [llm.🔥](https://github.com/dorjeduck/llm.mojo) by @[dorjeduck](https://github.com/dorjeduck): a Mojo port of this project\n\n- OpenCL\n  - [llm.c](https://github.com/krrishnarraj/llm.c) by @[krrishnarraj](https://github.com/krrishnarraj): an OpenCL port of this project\n\n- Rust\n  -  [llm.rs](https://github.com/yijunyu/llm.rs) by @[Yijun Yu](https://github.com/yijunyu): a Rust rewrite with the aim to have same performance\n  -  [llm.rs](https://github.com/ToJen/llm.rs) by @[ToJen](https://github.com/ToJen): a Rust port of this project\n\n- Swift\n  - [llm.swift](https://github.com/otabuzzman/llm.swift) by @[otabuzzman](https://github.com/otabuzzman): a Swift port of this project\n\n- Zig\n  - [llm.zig](https://github.com/Saimirbaci/llm.zig) by @[saimirbaci](https://github.com/Saimirbaci): a Zig port of this project\n \n- Habana Gaudi2\n  - [llm.tpc](https://github.com/abhilash1910/llm.tpc) by @[abhilash1910](https://github.com/abhilash1910): a Habana Gaudi2 port of this project \n\n- Nim\n  - [llm.nim](https://github.com/Vindaar/llm.nim) by @[Vindaar](https://github.com/Vindaar): a Nim port of this project\n\n## discussions\n\nWays of organizing 
development:\n\n- Experiencing a concrete issue with the repo? Use [Issues](https://github.com/karpathy/llm.c/issues).\n- Have some code to contribute? Open a [PR](https://github.com/karpathy/llm.c/pulls)\n- Chat about the repo, ask questions, etc.? Look at [Discussions](https://github.com/karpathy/llm.c/discussions).\n- Something faster? I created a new `#llmc` channel on my [Zero to Hero Discord channel](https://discord.gg/3zy8kqD9Cp).\n\n## license\n\nMIT\n"
  },
  {
    "path": "dev/cpu/matmul_forward.c",
    "content": "/*\nCPU Kernels for matmul forward pass.\n*/\n\n// Compile Examples:\n//\n//      MSVC: cl.exe /O2 /fp:fast /Qvec-report:2 /I. /I ..\\..\\dev matmul_forward.c\n//            cl.exe /O2 /fp:fast /Qvec-report:2 /arch:AVX /I. /I ..\\..\\dev matmul_forward.c\n//            cl.exe /O2 /fp:fast /Qvec-report:2 /arch:AVX2 /I. /I ..\\..\\dev matmul_forward.c\n//\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <math.h>\n#include <time.h>\n#include <unistd.h>\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid matmul_forward_cpu(float* out,\n                    const float* inp, const float* weight, const float* bias,\n                    int B, int T, int C, int OC) {\n    // OC is short for \"output channels\"\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // out will be (B,T,OC)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* out_bt = out + b * T * OC + t * OC;\n            const float* inp_bt = inp + b * T * C + t * C;\n            for (int o = 0; o < OC; o++) {\n                float val = (bias != NULL) ? 
bias[o] : 0.0f;\n                const float* wrow = weight + o*C;\n                for (int i = 0; i < C; i++) {\n                    val += inp_bt[i] * wrow[i];\n                }\n                out_bt[o] = val;\n            }\n        }\n    }\n}\n\nvoid matmul_forward_ngc92(float* out,\n    const float* inp, const float* weight, const float* bias,\n    int B, int T, int C, int OC) {\n    // most of the running time is spent here and in matmul_backward\n    // OC is short for \"output channels\"\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // out will be (B,T,OC)\n\n    // make sure the tiled loop will be correct, otherwise, fallback to slow version\n    #define LOOP_UNROLL 8\n\n    if (B * T % LOOP_UNROLL != 0) {\n        printf(\"MUST BE A MULTIPLE OF 8\"); // FIXME\n        return;\n    }\n\n    // collapse the B and T loops into one and turn it into a strided loop.\n    // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times\n    // for significant speed-ups.\n    for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {\n        for (int o = 0; o < OC; o++) {\n            // keep LOOP_UNROLL many results in register, initialized by the bias term.\n            float result[LOOP_UNROLL];\n            for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) {\n                result[ibt] = (bias != NULL) ? bias[o] : 0.0f;\n            }\n\n            // inner loops. 
Because we do LOOP_UNROLL steps of inner bt, we can cache\n            // the value of weight[i + o * C] and reuse it.\n            // we compile with -Ofast, so the compiler will turn the inner loop into a bunch of FMAs\n            for (int i = 0; i < C; i++) {\n                float w = weight[i + o * C];\n                for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) {\n                    int bt = obt + ibt;\n                    result[ibt] += inp[bt * C + i] * w;\n                }\n            }\n\n            // write back results to main memory\n            for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) {\n                int bt = obt + ibt;\n                out[bt * OC + o] = result[ibt];\n            }\n        }\n    }\n}\n\n#define NUM_KERNELS 2\n\nvoid matmul_forward(int kernel_num,\n    float* out,\n    const float* inp, const float* weight, const float* bias,\n    int B, int T, int C, int OC) {\n\n    switch (kernel_num) {\n        case 0:\n            matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);\n            break;\n        case 1:\n            matmul_forward_ngc92(out, inp, weight, bias, B, T, C, OC);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n\nvoid validate_results_cpu(const float* device_result, const float* cpu_reference, const char* name, int num_elements, float tolerance);\nfloat* make_random_float(size_t N);\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n    int OC = 768 * 4; // expansion of 4, e.g. 
in the MLP\n    int RUNS = 4; // number of times to run a kernel for benchmarks\n\n    srand(137);\n\n    float* out = make_random_float(B * T * OC);\n    float* inp = make_random_float(B * T * C);\n    float* weight = make_random_float(OC * C);\n    float* bias = make_random_float(OC);\n\n    float* grad_out = make_random_float(B * T * OC);\n    float* grad_inp = make_random_float(B * T * C);\n    float* grad_weight = make_random_float(OC * C);\n    float* grad_bias = make_random_float(OC);\n\n    printf(\"> Calculating reference\\n\");\n    matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);\n\n    for (int kernel_num = 0; kernel_num < NUM_KERNELS; kernel_num++) {\n        printf(\"> Verifying kernel #%d\\n\", kernel_num);\n\n        srand(137);\n\n        float* kernel_out = make_random_float(B * T * OC);\n        float* kernel_inp = make_random_float(B * T * C);\n        float* kernel_weight = make_random_float(OC * C);\n        float* kernel_bias = make_random_float(OC);\n\n        matmul_forward(kernel_num, kernel_out, kernel_inp, kernel_weight, kernel_bias, B, T, C, OC);\n\n        validate_results_cpu(kernel_out, out, \"out\", B * T * OC, 1e-5);\n\n        free(kernel_out);\n        free(kernel_inp);\n        free(kernel_weight);\n        free(kernel_bias);\n    }\n\n    printf(\"All kernels passed! 
Starting benchmarks.\\n\\n\");\n\n    for (int kernel_num = 0; kernel_num < NUM_KERNELS; kernel_num++) {\n        printf(\"> Running kernel #%d\\n\", kernel_num);\n        struct timespec start, end;\n        clock_gettime(CLOCK_MONOTONIC, &start);\n\n        for (int i = 0; i < RUNS; i++) {\n            matmul_forward(kernel_num, out, inp, weight, bias, B, T, C, OC);\n        }\n\n        clock_gettime(CLOCK_MONOTONIC, &end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;\n        printf(\"> Kernel #%d, (took %f ms)\\n\", kernel_num, time_elapsed_s * 1000);\n    }\n\n    // free memory\n    free(out);\n    free(inp);\n    free(weight);\n    free(bias);\n\n    free(grad_out);\n    free(grad_inp);\n    free(grad_weight);\n    free(grad_bias);\n\n    return 0;\n}\n\nfloat* make_random_float(size_t N) {\n    float* arr = (float*)malloc(N * sizeof(float));\n    for (size_t i = 0; i < N; i++) {\n        arr[i] = ((float)rand() / RAND_MAX) * 2.0 - 1.0; // range -1..1\n    }\n    return arr;\n}\n\nvoid validate_results_cpu(const float* kernel_result, const float* cpu_reference, const char* name, int num_elements, float tolerance) {\n    int nfaults = 0;\n    for (int i = 0; i < num_elements; i++) {\n        // print the first few comparisons\n        if (i < 5) {\n            printf(\"%f %f\\n\", cpu_reference[i], kernel_result[i]);\n        }\n        float t_eff = tolerance + fabs(cpu_reference[i]);\n        // ensure correctness for all elements.\n        if (fabs(cpu_reference[i] - kernel_result[i]) > t_eff) {\n            printf(\"Mismatch of %s at %d: CPU_ref: %f vs CPU_new: %f\\n\", name, i, cpu_reference[i], kernel_result[i]);\n            nfaults++;\n            if (nfaults >= 10) {\n                exit(EXIT_FAILURE);\n            }\n        }\n    }\n    if (nfaults > 0) {\n        exit(EXIT_FAILURE);\n    }\n    printf(\"OK\\n\");\n}"
  },
  {
    "path": "dev/cuda/Makefile",
    "content": "# Makefile for building dev/cuda kernels\n# Collects all the make commands in one file but each file also\n# has the compile and run commands in the header comments section.\n\n# Find nvcc (NVIDIA CUDA compiler)\nNVCC := $(shell which nvcc 2>/dev/null)\nifeq ($(NVCC),)\n\t\t$(error nvcc not found.)\nendif\n\nifneq ($(CI),true) # if not in CI, then use the GPU query\n  ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=\n    GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too\n    GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))\n  endif\nendif\n\n# Compiler flags\nifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=\n  CFLAGS = -O3 --use_fast_math\nelse\n  CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)]\nendif\n\nNVCCFLAGS = -lcublas -lcublasLt -std=c++17\nMPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/\n\n# Default rule for our CUDA files\n%: %.cu\n\t$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@\n\n# Build all targets\nTARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward  global_norm permute\n\nall: $(TARGETS)\nall_ptx:  $(TARGETS:%=%.ptx)\nall_sass: $(TARGETS:%=%.sass)\n\n# Individual targets: forward pass\nattention_forward: attention_forward.cu\nclassifier_fused: classifier_fused.cu\ncrossentropy_forward: crossentropy_forward.cu\nencoder_forward: encoder_forward.cu\ngelu_forward: gelu_forward.cu\nlayernorm_forward: layernorm_forward.cu\nfused_residual_forward: 
fused_residual_forward.cu\nresidual_forward: residual_forward.cu\nsoftmax_forward: softmax_forward.cu\ntrimat_forward: trimat_forward.cu\n# matmul fwd/bwd also uses OpenMP (optionally) and cuBLASLt libs\nmatmul_forward: matmul_forward.cu\n\t$(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward\n\n# Individual targets: backward pass\nattention_backward: attention_backward.cu\ncrossentropy_softmax_backward: crossentropy_softmax_backward.cu\nencoder_backward: encoder_backward.cu\ngelu_backward: gelu_backward.cu\nlayernorm_backward: layernorm_backward.cu\nmatmul_backward_bias: matmul_backward_bias.cu\nmatmul_backward: matmul_backward.cu\n\t$(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward\n\n# Update kernels\nadamw: adamw.cu\nglobal_norm: global_norm.cu\n\npermute: permute.cu\n\n# NCCL communication kernels\nnccl_all_reduce: nccl_all_reduce.cu\n\t$(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce\n\n# Generate PTX using cuobjdump\n%.ptx: %\n\tcuobjdump --dump-ptx $< > $@\n\n# Generate SASS using cuobjdump\n%.sass: %\n\tcuobjdump --dump-sass $< > $@\n\n# Run all targets\nrun_all: all\n\t@for target in $(TARGETS); do \\\n\t\techo \"\\n========================================\"; \\\n\t\techo \"Running $$target ...\"; \\\n\t\techo \"========================================\\n\"; \\\n\t\t./$$target; \\\n\tdone\n\n# Clean up\nclean:\n\trm -f $(TARGETS) *.ptx *.sass\n"
  },
  {
    "path": "dev/cuda/README.md",
    "content": "# dev/cuda\n\nThis directory is scratch space for developing various versions of the needed CUDA kernels. Each file develops a kernel, and usually multiple versions of that kernel that could have different running times and of different code or time complexity.\n\nSee the top of each file for how to compile and run the kernel. Alternatively, the commands are also all grouped in the `Makefile` in this directory for convenience.\n\nFor example, we can look at the top of `layernorm_forward.cu` to build the forward pass kernels for the LayerNorm:\n\n```bash\nnvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_forward.cu -o layernorm_forward\n```\n\nor simply\n\n```bash\nmake layernorm_forward\n```\n\nThe comments at the top then document the different versions of this kernel available, usually these are in increasing complexity and decreasing running times. For example, inspecting the comments in the file on top, the most naive kernel we can then run as:\n\n```bash\n./layernorm_forward 1\n```\n\nYou'll see that this first forwards the reference code on the CPU, then it runs kernel 1 on the GPU, compares the results to check for correctness, and then runs a number of configurations of this kernel (most often and most notably the block size), to time the kernel in these launch configurations. We can then run one of the faster kernels (kernel 4) instead:\n\n```bash\n./layernorm_forward 4\n```\n\nYou'll see that this matches all the CPU results but runs much much faster. The typical process from here on is we copy paste the kernel that ran fastest, adjust it manually (e.g. to hardcode the best block size) and drop it into the training code file, e.g. `train_gpt2.cu`.\n\nTo add a new version of a kernel, add the kernel to the corresponding file and adjust the docs. To add a new kernel, add the new file and adjust the Makefile. 
Run `make clean` to clean up binaries from your directory.\n\nIf you do not have a GPU or are having trouble with CUDA dependencies, you can run the benchmarks on the [Modal platform](http://modal.com). For example, to run the benchmark for the attention forward pass on an A100 GPU with 80GB of memory, you can run the following command:\n\n```bash\nGPU_MEM=80 modal run benchmark_on_modal.py --compile-command \"nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas\" --run-command \"./attention_forward 1\"\n```\n"
  },
  {
    "path": "dev/cuda/adamw.cu",
    "content": "/*\nKernels for the AdamW optimizer.\n\nReferences:\n  * https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html\n  * https://github.com/nvidia/apex/blob/master/csrc/multi_tensor_adam.cu\n\nCompile example:\nnvcc -lcublas -lcublasLt adamw.cu -o adamw\nnvcc -O3 --use_fast_math -lcublas -lcublasLt adamw.cu -o adamw\n\n./adamw\n\nTODO(general):\namsgrad=True\n\nTODO(perf):\ndtype\nthread coarsening/ILP\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <time.h>\n#include <cuda_runtime.h>\n#include \"common.h\"\n\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid adamw_cpu(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters, float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {\n    // adapted from: train_gpt2.c\n\n    for (int i = 0; i < num_parameters; i++) {\n        float param = params_memory[i];\n        float grad = grads_memory[i];\n\n        // update the first moment (momentum)\n        float m = beta1 * m_memory[i] + (1.0f - beta1) * grad;\n        // update the second moment (RMSprop)\n        float v = beta2 * v_memory[i] + (1.0f - beta2) * grad * grad;\n        // bias-correct both moments\n        float m_hat = m / (1.0f - powf(beta1, t));\n        float v_hat = v / (1.0f - powf(beta2, t));\n\n        // update\n        m_memory[i] = m;\n        v_memory[i] = v;\n        params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// utility functions\n\n// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation).\n// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda\n__device__ inline float lerp(float start, float end, float weight) {\n    
return fma(weight, end, fma(-weight, start, start));\n}\n\n// naive fused kernel\n__global__ void adamw_kernel1(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,\n                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {\n   int i = blockIdx.x * blockDim.x + threadIdx.x;\n   if (i >= num_parameters) return;  // guard\n   // update the first moment (momentum)\n   m_memory[i] = beta1 * m_memory[i] + (1.0f - beta1) * grads_memory[i];\n   // update the second moment (RMSprop)\n   v_memory[i] = beta2 * v_memory[i] + (1.0f - beta2) * grads_memory[i] * grads_memory[i];\n   float m_hat = m_memory[i] / beta1_correction;\n   float v_hat = v_memory[i] / beta2_correction;\n   params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * params_memory[i]);\n}\n\n// Slightly more optimized AdamW kernel by:\n// * loading data that is accessed more than once into registers,\n// * using optimized linear interpolation for the moment updates.\n__global__ void adamw_kernel2(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,\n                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {\n   int i = blockIdx.x * blockDim.x + threadIdx.x;\n   if (i >= num_parameters) return;  // guard\n   float grad = grads_memory[i];\n   float m = m_memory[i];\n   float v = v_memory[i];\n   // update the first moment (momentum)\n   m = lerp(grad, m, beta1);\n   m_memory[i] = m;\n   // update the second moment (RMSprop)\n   v = lerp(grad * grad, v, beta2);\n   v_memory[i] = v;\n   m /= beta1_correction;  // m_hat\n   v /= beta2_correction;  // v_hat\n   params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);\n}\n\n\n// 
----------------------------------------------------------------------------\n// kernel launcher\n\n// version 1: naive dispatch to naive kernel\nvoid adamw_dispatch1(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,\n                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {\n    unsigned int block_size = 512;\n    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);\n    adamw_kernel1<<<num_blocks, block_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,\n                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);\n    cudaCheck(cudaGetLastError());\n}\n\n// version 2: naive dispatch to slightly optimized kernel\nvoid adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,\n                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {\n    unsigned int block_size = 512;\n    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);\n    adamw_kernel2<<<num_blocks, block_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,\n                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid adamw(int kernel_num,\n           float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,\n           float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {\n    // calculate the m_hat and v_hat correction terms once as they are the same for every param/thread\n    float beta1_correction = 1.0f - powf(beta1, t);\n    float 
beta2_correction = 1.0f - powf(beta2, t);\n    switch (kernel_num) {\n        case 1:\n            adamw_dispatch1(params_memory, grads_memory, m_memory, v_memory, num_parameters,\n                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);\n            break;\n        case 2:\n            adamw_dispatch2(params_memory, grads_memory, m_memory, v_memory, num_parameters,\n                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    const long num_parameters = 1048576;\n    const int t = 10;\n\n    const float learning_rate = 1e-3f;\n    const float beta1 = 0.9f;\n    const float beta2 = 0.999f;\n    const float eps = 1e-8f;\n    const float weight_decay = 0.0f;\n\n    // create random data on host (to be used for the CPU reference implementation)\n    float* params_memory = make_random_float(num_parameters);\n    float* grads_memory = make_random_float(num_parameters);\n    float* m_memory = make_random_float(num_parameters);\n    float* v_memory = make_random_float_01(num_parameters);\n\n    // move to GPU\n    float* d_params_memory;\n    float* d_grads_memory;\n    float* d_m_memory;\n    float* d_v_memory;\n    cudaCheck(cudaMalloc(&d_params_memory, num_parameters * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_grads_memory, num_parameters * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_m_memory, num_parameters * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_v_memory, num_parameters * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_params_memory, params_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_grads_memory, grads_memory, num_parameters * 
sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_m_memory, m_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_v_memory, v_memory, num_parameters * sizeof(float), cudaMemcpyHostToDevice));\n\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // calculate the CPU reference (using default hyperparams)\n    clock_t start = clock();\n    adamw_cpu(params_memory, grads_memory, m_memory, v_memory, t, num_parameters);\n    clock_t end = clock();\n    // TODO: measure runtime with multiple runs\n    double elapsed_time_cpu = (double)(end - start) / CLOCKS_PER_SEC;\n\n    // calculate the GPU version (using default hyperparams)\n    adamw(kernel_num, d_params_memory, d_grads_memory, d_m_memory, d_v_memory, t, num_parameters);\n\n    // compare\n    printf(\"Checking correctness...\\n\");\n    printf(\"parameters:\\n\");\n    validate_result(d_params_memory, params_memory, \"params_memory\", num_parameters);\n    printf(\"first moment:\\n\");\n    validate_result(d_m_memory, m_memory, \"m_memory\", num_parameters);\n    printf(\"second moment:\\n\");\n    validate_result(d_v_memory, v_memory, \"v_memory\", num_parameters);\n    printf(\"All results match.\\n\\n\");\n\n    // now benchmark the kernel\n    int repeat_times = 1000;\n    float elapsed_time = benchmark_kernel(repeat_times, adamw, kernel_num,\n      d_params_memory, d_grads_memory, d_m_memory, d_v_memory, t, num_parameters,\n      learning_rate, beta1, beta2, eps, weight_decay);\n    printf(\"time gpu %.4f ms\\n\", elapsed_time);\n    printf(\"time cpu %.4f ms\\n\", elapsed_time_cpu);\n\n    // cleanup\n    free(params_memory);\n    free(grads_memory);\n    free(m_memory);\n    free(v_memory);\n    cudaCheck(cudaFree(d_params_memory));\n    cudaCheck(cudaFree(d_grads_memory));\n    
cudaCheck(cudaFree(d_m_memory));\n    cudaCheck(cudaFree(d_v_memory));\n\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/attention_backward.cu",
    "content": "/*\nKernels for attention backward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt attention_backward.cu -o attention_backward\n\nversion 1 is a naive first version\nOMP_NUM_THREADS=32 ./attention_backward 1\n\nversion 2 much ensures better load-balancing by having independent threads for each batch and attention head\nOMP_NUM_THREADS=32 ./attention_backward 2\n\nversion 3 uses a full warp to calculate each result (instead of a thread), which enables coalesced memory access\nOMP_NUM_THREADS=32 ./attention_backward 3\n\nversion 4 improves data reuse in registers by doing 8 values of t3 in one warp.\nOMP_NUM_THREADS=32 ./attention_backward 4\n\nversion 5 reduces the amount of non-fp32 instructions needed by avoiding ifs\nOMP_NUM_THREADS=32 ./attention_backward 5\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <float.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <cooperative_groups/scan.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\n/*\nNOTE:\nThis version of attention_forward is modified to be consistent with the\nattention_forward GPU kernel in the following way small but important way:\n- preatt is only QUERY @ KEY, without the scale\n- the scale instead moved and fused into the softmax\n- the full preatt matrix is materialized, even the parts that get masked out\n    - this doesn't actually change anything due to masking, but it lets us\n      easily compare to the GPU version, which also does the full, dense sgemm\nIn this way we'll be able to make sure that preatt and att agree CPU vs GPU\n*/\nvoid attention_forward_cpu(float* out, float* preatt, float* att,\n                            float* inp,\n                            int B, int T, int C, int NH) {\n    // input is (B, T, 3C) holding the query, key, value (Q, 
K, V) vectors\n    // preatt, att are (B, NH, T, T). NH = number of heads, T = sequence length\n    // that holds the pre-attention and post-attention scores (used in backward)\n    // output is (B, T, C)\n    // attention is the only layer that mixes information across time\n    // every other operation is applied at every (b,t) position independently\n    // (and of course, no layer mixes information across batch)\n    int C3 = C*3;\n    int hs = C / NH; // head size\n    float scale = 1.0 / sqrtf(hs);\n\n    #pragma omp parallel for collapse(3)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            for (int h = 0; h < NH; h++) {\n                float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n                // pass 1: calculate query dot key and maxval\n                float maxval = -FLT_MAX;\n                for (int t2 = 0; t2 < T; t2++) { // used to be t2 <= t\n                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n\n                    // (query_t) dot (key_t2)\n                    float val = 0.0f;\n                    for (int i = 0; i < hs; i++) {\n                        val += query_t[i] * key_t2[i];\n                    }\n                    if (val > maxval) {\n                        maxval = val;\n                    }\n\n                    preatt_bth[t2] = val;\n                }\n\n                // pass 2: calculate the exp and keep track of sum\n                // maxval is being calculated and subtracted only for numerical stability\n                float expsum = 0.0f;\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float expv = expf(scale * (preatt_bth[t2] - maxval));\n                    expsum += expv;\n                    att_bth[t2] = expv;\n                }\n                float expsum_inv = 
expsum == 0.0f ? 0.0f : 1.0f / expsum;\n\n                // pass 3: normalize to get the softmax\n                for (int t2 = 0; t2 < T; t2++) {\n                    if (t2 <= t) {\n                        att_bth[t2] *= expsum_inv;\n                    } else {\n                        // causal attention mask. not strictly necessary to set to zero here\n                        // only doing this explicitly for debugging and checking to PyTorch\n                        att_bth[t2] = 0.0f;\n                    }\n                }\n\n                // pass 4: accumulate weighted values into the output of attention\n                float* out_bth = out + b * T * C + t * C + h * hs;\n                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value\n                    float att_btht2 = att_bth[t2];\n                    for (int i = 0; i < hs; i++) {\n                        out_bth[i] += att_btht2 * value_t2[i];\n                    }\n                }\n            }\n        }\n    }\n}\n\n// NOTE: Also contains the re-shuffling of the exact position of \"scale\"\n// and when it is applied (after preatt, not \"during\" preatt)\n// also, full matrices are materialized, even the parts that get masked out\nvoid attention_backward_cpu(float* dinp, float* dpreatt, float* datt,\n                            float* dout, float* inp, float* att,\n                            int B, int T, int C, int NH) {\n    // inp/dinp are (B, T, 3C) Q,K,V\n    // att/datt/dpreatt are (B, NH, T, T)\n    // dout is (B, T, C)\n    int C3 = C*3;\n    int hs = C / NH; // head size\n    float scale = 1.0 / sqrtf(hs);\n\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            for (int h = 0; h < NH; h++) {\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n                float* datt_bth 
= datt + b*NH*T*T + h*T*T + t*T;\n                float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;\n                float* dquery_t = dinp + b * T * C3 + t * C3 + h * hs;\n                float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n\n                // backward pass 4, through the value accumulation\n                float* dout_bth = dout + b * T * C + t * C + h * hs;\n                for (int t2 = 0; t2 < T; t2++) { // ADJUSTED! this was t2 <= t (see note on function)\n                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value\n                    float* dvalue_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C*2;\n                    for (int i = 0; i < hs; i++) {\n                        // in the forward pass this was:\n                        // out_bth[i] += att_bth[t2] * value_t2[i];\n                        // so now we have:\n                        datt_bth[t2] += value_t2[i] * dout_bth[i];\n                        dvalue_t2[i] += att_bth[t2] * dout_bth[i];\n                    }\n                }\n\n                // backward pass 2 & 3, the softmax\n                // note that softmax (like e.g. tanh) doesn't need the input (preatt) to backward\n                for (int t2 = 0; t2 <= t; t2++) {\n                    for (int t3 = 0; t3 <= t; t3++) {\n                        float indicator = t2 == t3 ? 
1.0f : 0.0f;\n                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);\n                        dpreatt_bth[t3] += scale * local_derivative * datt_bth[t2];\n                    }\n                }\n\n                // backward pass 1, the query @ key matmul\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n                    float* dkey_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n                    for (int i = 0; i < hs; i++) {\n                        // in the forward pass this was:\n                        // preatt_bth[t2] += query_t[i] * key_t2[i]\n                        // so now we have:\n                        dquery_t[i] += key_t2[i] * dpreatt_bth[t2];\n                        dkey_t2[i] += query_t[i] * dpreatt_bth[t2];\n                    }\n                }\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n// the forward pass that is the sequence [permute, sgemm, softmax, sgemm, unpermute]\n\n__global__ void permute_kernel(float* q, float* k, float* v,\n                               const float* inp,\n                               int B, int N, int NH, int d) {\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)\n    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n        q[idx] = inp[inp_idx];\n        k[idx] = 
inp[inp_idx + NH * d];\n        v[idx] = inp[inp_idx + 2 * (NH * d)];\n    }\n}\n\n__global__ void permute_kernel_backward(float* dinp,\n                                        const float* dq, const float* dk, const float* dv,\n                                        int B, int N, int NH, int d) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n        dinp[inp_idx] += dq[idx];\n        dinp[inp_idx + NH * d] += dk[idx];\n        dinp[inp_idx + 2 * (NH * d)] += dv[idx];\n    }\n}\n\n__global__ void unpermute_kernel(const float* inp, float *out, int B, int N, int NH, int d) {\n   // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        out[other_idx] = inp[idx];\n    }\n}\n\n__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        dinp[idx] += dout[other_idx];\n    }\n}\n\n__device__ 
float& vec_at(float4& vec, int index) {\n    return reinterpret_cast<float*>(&vec)[index];\n}\n\n__device__ float vec_at(const float4& vec, int index) {\n    return reinterpret_cast<const float*>(&vec)[index];\n}\n\n__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {\n    // inp, out shape: (N, T, T), where N = B * NH\n    // fuses the multiplication by scale inside attention\n    // directly autoregressive, so we only compute the lower triangular part\n    // uses the online softmax algorithm\n    assert(T % 4  == 0);\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= N * T) {\n        return;\n    }\n    int own_pos = idx % T;\n    int pos_by_4 = own_pos / 4;\n\n    // one row of inp, i.e. inp[idx, :] of shape (T,)\n    const float* x = inp + idx * T;\n\n    // not INF, so we don't get NaNs accidentally when subtracting two values.\n    float maxval = -FLT_MAX;\n    float sumval = 0.0f;\n\n    const float4* x_vec = reinterpret_cast<const float4*>(x);\n    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {\n        float4 v = x_vec[i];\n        float old_maxval = maxval;\n        for(int k = 0; k < 4; ++k) {\n            maxval = fmaxf(maxval, vec_at(v, k));\n        }\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        for(int k = 0; k < 4; ++k) {\n            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));\n        }\n    }\n\n    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {\n        float old_maxval = maxval;\n        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));\n    }\n\n    float 
global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});\n    sumval *= expf(inv_temperature * (maxval - global_maxval));\n\n    float sum = cg::reduce(warp, sumval, cg::plus<float>{});\n    float norm = 1.f / sum;\n\n    // divide the whole row by the sum\n    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {\n        // recalculation is faster than doing the round-trip through memory.\n        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));\n        __stcs(out + idx * T + i, ev * norm);\n    }\n}\n\n// naive kernel to backward through an autoregressive softmax, just to get correctness\n__global__ void softmax_autoregressive_backward_kernel1(float* dpreatt, const float* datt, const float* att,\n                                                     int B, int T, int C, int NH) {\n    // dpreatt, datt, att are all (B, NH, T, T)\n    int t3 = blockIdx.x * blockDim.x + threadIdx.x;\n    if (t3 < T) {\n        int hs = C / NH; // head size\n        float scale = 1.0f / sqrtf(hs);\n        for (int b = 0; b < B; b++) {\n            for (int h = 0; h < NH; h++) {\n                for (int t = t3; t < T; t++) {\n                    const float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n                    const float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;\n                    float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;\n                    float accum = 0.0f;\n                    for (int t2 = 0; t2 <= t; t2++) {\n                        float indicator = t2 == t3 ? 
1.0f : 0.0f;\n                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);\n                        accum +=  scale * local_derivative * datt_bth[t2];\n                    }\n                    dpreatt_bth[t3] = accum;\n                }\n            }\n        }\n    }\n}\n\n// parallelize across t,b,h\n__global__ void softmax_autoregressive_backward_kernel2(float* dpreatt, const float* datt, const float* att,\n                                                     int B, int T, int C, int NH) {\n    int t3 = blockIdx.x * blockDim.x + threadIdx.x;\n    int idx = blockIdx.y * T * T;\n    if (t3 >= T) { return; }\n\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    for (int t = t3; t < T; t++) {\n        float result = 0.0;\n        const float* att_bth = att + idx + t*T;\n        const float* datt_bth = datt + idx + t*T;\n        float* dpreatt_bth = dpreatt + idx + t*T;\n\n        for (int t2 = 0; t2 <= t; t2++) {\n            float indicator = t2 == t3 ? 
1.0f : 0.0f;\n            float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);\n            result += scale * local_derivative * datt_bth[t2];\n        }\n\n        dpreatt_bth[t3] = result;\n    }\n}\n\n// parallelize across t,b,h\n__global__ void softmax_autoregressive_backward_kernel3(float* dpreatt, const float* datt, const float* att,\n                                                     int B, int T, int C, int NH) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int t3 = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n\n    int idx = blockIdx.y * T * T;\n    if (t3 >= T) { return; }\n\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    for (int t = t3; t < T; t++) {\n        float result = 0.0;\n        const float* att_bth = att + idx + t*T;\n        const float* datt_bth = datt + idx + t*T;\n        float* dpreatt_bth = dpreatt + idx + t*T;\n        const float att_at_t3 = att_bth[t3];\n\n        for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {\n            float indicator = t2 == t3 ? 
1.0f : 0.0f;\n            float local_derivative = att_bth[t2] * (indicator - att_at_t3);\n            result += local_derivative * datt_bth[t2];\n        }\n\n        result = cg::reduce(warp, result, cg::plus<float>());\n        if(warp.thread_rank() == 0) {\n            dpreatt_bth[t3] = scale * result;\n        }\n    }\n}\n__global__ void softmax_autoregressive_backward_kernel4(float* __restrict__ dpreatt, const float* __restrict__ datt,\n                                                        const float* __restrict__ att,\n                                                        int B, int T, int C, int NH) {\n    constexpr int UNROLL = 8;\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int t3 = UNROLL * (blockIdx.x * warp.meta_group_size() + warp.meta_group_rank());\n\n    int idx = blockIdx.y * T * T;\n    if (t3 >= T) { return; }\n\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n\n    // the innermost loop combines different values of t2 with different values of t.\n    // by handling [t3, t3 + UNROLL) in one thread, we get much better memory reuse:\n    // any t3/t-dependent value can be loaded once before the t2 loop.\n    // within the t2 loop, we can combine each loaded value with each of the UNROLL\n    // pre-loaded values, thus cutting memory reads by a factor of ~UNROLL.\n\n    // one iteration of this loop has to handle all UNROLL columns t3 + k;\n    // this may lead to some invalid indices (t3 + k > t); therefore, we have several\n    // early-outs in the iteration over k below.\n    for (int t = t3; t < T; t++) {\n        float result[UNROLL] = {};\n        const float* att_bth = att + idx + t * T;\n        const float* datt_bth = datt + idx + t * T;\n        float* dpreatt_bth = dpreatt + idx + t * T;\n\n        float att_at_t3[UNROLL];\n        for(int k = 0; k < UNROLL; ++k) {\n            if (t < t3 + k) continue;\n           
 att_at_t3[k] = att_bth[t3 + k];\n        }\n\n        for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {\n            float att_t2 = att_bth[t2];\n            float datt_t2 = datt_bth[t2];\n            for(int k = 0; k < UNROLL; ++k) {\n                if (t < t3 + k) continue;\n                float indicator = t2 == (t3 + k) ? 1.0f : 0.0f;\n                float local_derivative = att_t2 * (indicator - att_at_t3[k]);\n                result[k] += local_derivative * datt_t2;\n            }\n        }\n\n        for(int k = 0; k < UNROLL; ++k) {\n            result[k] = cg::reduce(warp, result[k], cg::plus<float>());\n        }\n        if (warp.thread_rank() < UNROLL) {\n            dpreatt_bth[t3 + warp.thread_rank()] = scale * result[warp.thread_rank()];\n        }\n    }\n}\n\n__global__ void softmax_autoregressive_backward_kernel5(float* __restrict__ dpreatt, const float* __restrict__ datt,\n                                                        const float* __restrict__ att,\n                                                        int B, int T, int C, int NH) {\n    constexpr int UNROLL = 8;\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int t3 = UNROLL * (blockIdx.x * warp.meta_group_size() + warp.meta_group_rank());\n\n    int idx = blockIdx.y * T * T;\n    if (t3 >= T) { return; }\n\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    for (int t = t3; t < T; t++) {\n        float result[UNROLL] = {};\n        const float* att_bth = att + idx + t * T;\n        const float* datt_bth = datt + idx + t * T;\n        float* dpreatt_bth = dpreatt + idx + t * T;\n\n        float att_at_t3[UNROLL];\n        for(int k = 0; k < UNROLL; ++k) {\n            // if t < t3+k, we're out of bounds.\n            // in that case, we don't care what we read, because later on,\n            // we won't write the 
corresponding result. So just clip to\n            // make sure this is a valid (in-bounds) memory access.\n            att_at_t3[k] = att_bth[min(t, t3 + k)];\n        }\n\n        // the code below is actually just a for loop; except,\n        // we have to do something special in one iteration in\n        // the middle, and an if turned out to have significant\n        // performance impact.\n        // so we split the loop in three parts. Ugly, but effective.\n\n        // the beginning/end loop does the same thing, so we write the code\n        // just once in a lambda. In this step, we're guaranteed that\n        // indicator == 0\n        auto loop_step = [&](int t2){\n            float p = att_bth[t2] * datt_bth[t2];\n            for (int k = 0; k < UNROLL; ++k) {\n                result[k] -= p * att_at_t3[k];\n            }\n        };\n\n        // Now the actual loop.\n        {\n            // declare the loop iterator. Needs to be kept across the\n            // three different parts, so it's not a local variable in\n            // the for loop.\n            int t2 = warp.thread_rank();\n\n            // first part, as long as t2 < t3, indicator == 0\n            for (; t2 < t3; t2 += warp.size()) {\n                loop_step(t2);\n            }\n\n            // because k <= warp.size() (==32), the event that t3+k == t2\n            // has to happen at this particular step.\n            static_assert(UNROLL <= 32, \"UNROLL is too large, this won't produce correct results.\");\n            if (t2 <= t) {\n                float att_t2 = att_bth[t2];\n                float datt_t2 = datt_bth[t2];\n                float p = att_t2 * datt_t2;\n                for (int k = 0; k < UNROLL; ++k) {\n                    float indicator = t2 == (t3 + k) ? 
1.0f : 0.0f;\n                    result[k] += p * (indicator - att_at_t3[k]);\n                }\n                t2 += warp.size();\n            }\n\n            // rest of the loop, indicator == 0 again\n            for (; t2 <= t; t2 += warp.size()) {\n                loop_step(t2);\n            }\n        }\n\n        for(int k = 0; k < UNROLL; ++k) {\n            result[k] = cg::reduce(warp, result[k], cg::plus<float>());\n        }\n\n        // when storing, we need to check that this is actually a valid result.\n        // here, warp.thread_rank() corresponds to `k` in the previous loops.\n        if (warp.thread_rank() < UNROLL && t >= t3 + warp.thread_rank()) {\n            dpreatt_bth[t3 + warp.thread_rank()] = scale * result[warp.thread_rank()];\n        }\n    }\n}\n\n\n// I want `BlockSize` to be statically known to the compiler, thus we get a template here.\n// This kernel takes a step back, and looks at the original CPU code again. We have some simple outer loops\n// that are independent (b, t, h), and then the inner loops over (t2, t3) where we're combining elements -- this is\n// where we can reuse data and be more efficient\n// => handle b, t, h  through block indices; each block does all the work for the (t2, t3) loop cooperatively.\n// Now we have two nested loops, and in the inner instruction, we combine indexing from both => this calls for\n// loop tiling, and lifting some of the memory ops out of the loop.\n// We're in luck here;  if we tile so that t3 is the outer loop, we can get a single write op per result, AND also cache\n// the t2-indexed part of the computation, which is the problematic one because it contains a multiplication that now we\n// do not have to repeat over and over.\n// => do an outer t3 loop where each thread gets one t3 index. Then, do an outer t2 loop in steps of BlockSize, and\n// prepare BlockSize many elements for the inner loop. 
Here, each thread calculates one element and stores it in shmem.\n// Then, in the inner t2 loop, each thread reads *all* the elements previously stored and does its computations.\n// This way, we do 3*BlockSize loads, but BlockSize^2 computation steps => This kernel is now entirely compute bound.\n// To fix up the compute issues, as above, we replace ifs in memory reading with min, and also split the inner loop\n// into a large region where we don't have to calculate the indicator, and a small, costly region where we do.\ntemplate<int BlockSize>\n__global__ void __launch_bounds__(BlockSize) softmax_autoregressive_backward_kernel6(float* dpreatt, const float* datt, const float* att,\n                                                        int B, int T, int C, int NH) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    __shared__ float att_bth_s[BlockSize];\n\n    int idx = blockIdx.y;\n    int t = blockIdx.x;\n\n    att += idx * T * T;\n    datt += idx * T * T;\n    dpreatt += idx * T * T;\n\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    const float* att_bth = att + t * T;\n    const float* datt_bth = datt + t * T;\n    float* dpreatt_bth = dpreatt + t * T;\n\n    int block_steps = ceil_div(t+1, BlockSize);\n    // very important: This loop condition needs to be the same for all threads.\n    // even if a thread later on is not going to do any work, it needs to participate in the\n    // data loading process!\n    for (int t3f = 0; t3f < block_steps; ++t3f) {\n        int t3 = t3f * BlockSize + block.thread_rank();\n        float acc = 0.f;\n        float at3 = att_bth[t3];\n        for (int t2b = 0; t2b <= t; t2b += BlockSize) {\n            int end = min(t + 1 - t2b, BlockSize);\n            block.sync();\n            {\n                int t2i = block.thread_rank();\n                int t2 = min(t, t2b + t2i);\n                att_bth_s[t2i] = att_bth[t2] * datt_bth[t2];\n            
}\n\n            block.sync();\n            if(t3f * BlockSize == t2b) {\n                for (int t2i = 0; t2i < end; t2i++) {\n                    int t2 = t2b + t2i;\n                    float indicator = t2 == t3 ? 1.0f : 0.0f;\n                    acc += att_bth_s[t2i] * (indicator - at3);\n                }\n            } else {\n                for (int t2i = 0; t2i < end; t2i++) {\n                    acc +=  att_bth_s[t2i] * (0.f - at3);\n                }\n            }\n        }\n        dpreatt_bth[t3] = scale * acc;\n    }\n}\n\n// Actually disentangling the loops and simplifying the resulting math gives us this pretty nice kernel.\ntemplate<int BlockSize>\n__global__ void softmax_autoregressive_backward_kernel7(float* dpreatt, const float* datt, const float* att,\n                                                        int B, int T, int C, float scale) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    __shared__ float block_acc[32];\n\n    int idx = blockIdx.y;\n    int t = blockIdx.x;\n\n    att += idx * T * T;\n    datt += idx * T * T;\n    dpreatt += idx * T * T;\n\n    const float* att_bth = att + t * T;\n    const float* datt_bth = datt + t * T;\n    float* dpreatt_bth = dpreatt + t * T;\n\n    if(warp.meta_group_rank() == 0) {\n        block_acc[warp.thread_rank()] = 0;\n    }\n\n    float local_sum = 0;\n    for(int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {\n        local_sum += att_bth[t2] * datt_bth[t2];\n    }\n\n    block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus<float>{});\n    block.sync();\n    local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus<float>{});\n\n    for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {\n        float acc = att_bth[t3] * (datt_bth[t3] - local_sum);\n        dpreatt_bth[t3] = scale * acc;\n    }\n}\n\n// The slightly less 
pretty version of kernel 7. Adding in all the dirty tricks that can give us a few more percent\n//  - streaming memory access instructions\n//  - reordering blocks to prevent tail effect\n//  - multiple values of T per block\ntemplate<int BlockSize>\n__global__ void softmax_autoregressive_backward_kernel8(float* dpreatt, const float* datt, const float* att,\n                                                        int B, int T, int C, float scale) {\n    namespace cg = cooperative_groups;\n    constexpr int T_per_block = 4;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    __shared__ float block_acc[32];\n\n    int idx = blockIdx.y;\n    // go through blocks in reverse order, so the slowest block starts first\n    int t0 = T - 1 - T_per_block*blockIdx.x;\n\n    att += idx * T * T;\n    datt += idx * T * T;\n    dpreatt += idx * T * T;\n\n    if (warp.meta_group_rank() == 0) {\n        block_acc[warp.thread_rank()] = 0;\n    }\n\n    for(int to = 0; to < T_per_block; ++to) {\n        int t = t0 - to;\n        if(t < 0) return;\n        const float* att_bth = att + t * T;\n        const float* datt_bth = datt + t * T;\n        float* dpreatt_bth = dpreatt + t * T;\n\n        float local_sum = 0;\n        for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {\n            local_sum += att_bth[t2] * datt_bth[t2];\n        }\n\n        block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus<float>{});\n        block.sync();\n        local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus<float>{});\n\n        for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {\n            // don't touch the cache. 
Some parts will still be here from the previous loop, and\n            // we want to exploit those.\n            float acc = __ldcs(att_bth + t3) * (__ldcs(datt_bth + t3) - local_sum);\n            __stcs(dpreatt_bth + t3, scale * acc);\n        }\n    }\n}\n\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\n// attention forward pass kernel\nvoid attention_forward(float* out, float* vaccum, float* qkvr, float* preatt, float* att,\n                       const float* inp,\n                       int B, int T, int C, int NH,\n                       const int block_size) {\n    // inp is (B, T, 3C) QKV\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n    int HS = C / NH; // head size\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    float *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    int total_threads = B * NH * T * HS;\n    int num_blocks = ceil_div(total_threads, block_size);\n    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);\n\n    // batched matrix multiply with cuBLAS\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                                     CUBLAS_OP_T, CUBLAS_OP_N,\n                                     T, T, HS,\n                                     &alpha,\n                                     k, HS, T * HS,\n                                     q, HS, T * HS,\n                                     &beta,\n                                     preatt, T, T * T,\n                                     B * NH));\n\n    // multiply all elements of preatt elementwise by scale\n    float scale = 1.0 / sqrtf(HS);\n    int softmax_block_size = 256;\n    int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);\n    softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, 
preatt, B * NH, T);\n\n    // new approach: first cuBLAS another batched matmul\n    // vaccum = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                                     CUBLAS_OP_N, CUBLAS_OP_N,\n                                     HS, T, T,\n                                     &alpha,\n                                     v, HS, T * HS,\n                                     att, T, T * T,\n                                     &beta,\n                                     vaccum, HS, T * HS,\n                                     B * NH));\n\n    // now unpermute\n    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n    num_blocks = ceil_div(B * T * C, block_size);\n    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);\n}\n\nvoid launch_softmax_1(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    int num_blocks = ceil_div(T, block_size);\n    softmax_autoregressive_backward_kernel1<<<dim3(num_blocks, B*NH), block_size>>>(dpreatt, datt, att, B, T, C, NH);\n}\n\nvoid launch_softmax_2(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    int num_blocks = ceil_div(T, block_size);\n    softmax_autoregressive_backward_kernel2<<<dim3(num_blocks, B*NH), block_size>>>(dpreatt, datt, att, B, T, C, NH);\n}\n\nvoid launch_softmax_3(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    int num_blocks = ceil_div(32*T, block_size);\n    softmax_autoregressive_backward_kernel3<<<dim3(num_blocks, B*NH), block_size>>>(dpreatt, datt, att, B, T, C, NH);\n}\n\nvoid launch_softmax_4(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    int num_blocks = ceil_div(32/8*T, block_size);\n    softmax_autoregressive_backward_kernel4<<<dim3(num_blocks, B*NH), 
block_size>>>(dpreatt, datt, att, B, T, C, NH);\n}\n\nvoid launch_softmax_5(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    int num_blocks = ceil_div(32/8*T, block_size);\n    softmax_autoregressive_backward_kernel5<<<dim3(num_blocks, B*NH), block_size>>>(dpreatt, datt, att, B, T, C, NH);\n}\n\ntemplate<class Launcher>\nvoid dispatch_launch(Launcher&& launch, int block_size) {\n    switch(block_size) {\n        case 32:\n            return launch(std::integral_constant<int, 32>{});\n        case 64:\n            return launch(std::integral_constant<int, 64>{});\n        case 128:\n            return launch(std::integral_constant<int, 128>{});\n        case 256:\n            return launch(std::integral_constant<int, 256>{});\n        case 512:\n            return launch(std::integral_constant<int, 512>{});\n        case 1024:\n            return launch(std::integral_constant<int, 1024>{});\n        default:\n            assert(false && \"Invalid block size\");\n    }\n}\n\nvoid launch_softmax_6(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    auto launch = [&](auto int_const) {\n        softmax_autoregressive_backward_kernel6<int_const.value><<<dim3(T, B * NH), int_const.value>>>(dpreatt, datt, att, B, T, C, NH);\n    };\n    dispatch_launch(launch, block_size);\n}\n\nvoid launch_softmax_7(float* dpreatt, float* datt, const float* att, int B, int T, int C, int NH, int block_size) {\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    auto launch = [&](auto int_const) {\n        constexpr int block_size = int_const.value;\n        softmax_autoregressive_backward_kernel7<block_size><<<dim3(T, B * NH), block_size>>>\n                                                              (dpreatt, datt, att, B, T, C, scale);\n    };\n    dispatch_launch(launch, block_size);\n}\n\nvoid launch_softmax_8(float* dpreatt, float* datt, const float* att, int B, 
int T, int C, int NH, int block_size) {\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    auto launch = [&](auto int_const) {\n        constexpr int block_size = int_const.value;\n        softmax_autoregressive_backward_kernel8<block_size><<<dim3(T / 4, B * NH), block_size>>>\n                                                              (dpreatt, datt, att, B, T, C, scale);\n    };\n    dispatch_launch(launch, block_size);\n}\n\n// the sequence of transformations in this compound op is:\n// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)\ntemplate<class SoftmaxKernel>\nvoid attention_backward1(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* dvaccum,\n                        const float* dout,\n                        const float* inp, const float* qkvr, const float* preatt, const float* att, const float* vaccum,\n                        int B, int T, int C, int NH,\n                        SoftmaxKernel softmax_autoregressive_backward,\n                        const int block_size) {\n    int HS = C / NH; // head size\n    const float alpha = 1.0f;\n    const float beta = 1.0f; // note beta = 1.0f so that we accumulate gradients (+=)\n    // unpack convenience pointers into q, k, v\n    const float *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    float *dq, *dk, *dv;\n    dq = dqkvr + 0 * B * T * C;\n    dk = dqkvr + 1 * B * T * C;\n    dv = dqkvr + 2 * B * T * C;\n\n    // backward through the unpermute operation\n    int num_blocks = ceil_div(B * T * C, block_size);\n    unpermute_kernel_backward<<<num_blocks, block_size>>>(dvaccum, dout, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n\n    // backward into datt\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                            CUBLAS_OP_T, CUBLAS_OP_N,\n                            T, T, HS,\n                            &alpha,\n     
                       v, HS, T * HS,\n                            dvaccum, HS, T * HS,\n                            &beta,\n                            datt, T, T * T,\n                            B * NH));\n\n    // backward into dv\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n            CUBLAS_OP_N, CUBLAS_OP_T,\n            HS, T, T,\n            &alpha,\n            dvaccum, HS, T * HS,\n            att, T, T * T,\n            &beta,\n            dv, HS, T * HS,\n            B * NH));\n\n    // backward into preatt\n    softmax_autoregressive_backward(dpreatt, datt, att, B, T, C, NH, block_size);\n    cudaCheck(cudaGetLastError());\n\n    // backward into q\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                            CUBLAS_OP_N, CUBLAS_OP_N,\n                            HS, T, T,\n                            &alpha,\n                            k, HS, T * HS,\n                            dpreatt, T, T * T,\n                            &beta,\n                            dq, HS, T * HS,\n                            B * NH));\n    // backward into k\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                            CUBLAS_OP_N, CUBLAS_OP_T,\n                            HS, T, T,\n                            &alpha,\n                            q, HS, T * HS,\n                            dpreatt, T, T * T,\n                            &beta,\n                            dk, HS, T * HS,\n                            B * NH));\n\n    // backward into inp\n    num_blocks = ceil_div(B * NH * T * HS, block_size);\n    permute_kernel_backward<<<num_blocks, block_size>>>(dinp, dq, dk, dv, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid attention_backward(int kernel_num,\n                        float* dinp, float* dqkvr, float* dpreatt, float* datt, float* dvaccum,\n                        const float* dout,\n                        const float* inp, const float* qkvr, 
const float* preatt, const float* att, const float* vaccum,\n                        int B, int T, int C, int NH,\n                        const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_1, block_size);\n            break;\n        case 2:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_2, block_size);\n            break;\n        case 3:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_3, block_size);\n            break;\n        case 4:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_4, block_size);\n            break;\n        case 5:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_5, block_size);\n            break;\n        case 6:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_6, block_size);\n            break;\n        case 7:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_7, block_size);\n            break;\n        case 8:\n            attention_backward1(dinp, dqkvr, dpreatt, datt, dvaccum, dout, inp, qkvr, preatt, att, vaccum, B, T, C, NH,\n                                launch_softmax_8, block_size);\n            break;\n        default:\n            
printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    // hyperparameters\n    int B = 4;\n    int T = 1024;\n    int C = 768;\n    int NH = 12;\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // create the host memory for the forward pass\n    float* inp = make_random_float(B * T * 3 * C);\n    float* qkvr = (float*)malloc(B * T * 3 * C * sizeof(float));\n    float* preatt = (float*)malloc(B * NH * T * T * sizeof(float));\n    float* att = (float*)malloc(B * NH * T * T * sizeof(float));\n    float* vaccum = (float*)malloc(B * T * C * sizeof(float));\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n\n    // execute the forward pass on the CPU\n    attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);\n\n    // create device memory for the forward pass\n    float *d_inp, *d_qkvr, *d_preatt, *d_att, *d_vaccum, *d_out;\n    cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));\n    // copy over the input\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));\n\n    // execute the forward pass on the GPU\n    const int block_size = 256;\n    attention_forward(d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size);\n\n    // check that preatt, att, and out match between the CPU and GPU versions\n    printf(\"Checking the forward pass CPU <-> GPU...\\n\");\n    
printf(\"[preatt]\\n\"); validate_result(d_preatt, preatt, \"preatt\", B * T * C, 5e-3f);\n    printf(\"[att]\\n\");    validate_result(d_att, att, \"att\", B * T * C, 1e-3f);\n    printf(\"[out]\\n\");    validate_result(d_out, out, \"out\", B * T * C, 1e-3f);\n\n    // set up the memory for the backward pass\n    float* dout = make_random_float(B * T * C); // the gradients on the output\n    float* dinp = make_zeros_float(B * T * 3 * C); // zeros for all else, to += into\n    float* dpreatt = make_zeros_float(B * NH * T * T);\n    float* datt = make_zeros_float(B * NH * T * T);\n\n    // call backward() on the CPU to get our reference gradients\n    attention_backward_cpu(dinp, dpreatt, datt, dout, inp, att, B, T, C, NH);\n\n    // create device memory for the backward pass\n    float *d_dinp, *d_dqkvr, *d_dpreatt, *d_datt, *d_dvaccum, *d_dout;\n    cudaCheck(cudaMalloc(&d_dinp, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dqkvr, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dpreatt, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_datt, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dvaccum, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(float)));\n    // copy over the dout gradients that starts the backprop chain\n    cudaCheck(cudaMemcpy(d_dout, dout, B * T * C * sizeof(float), cudaMemcpyHostToDevice));\n    // memset all the other memory to zeros, to += into\n    cudaCheck(cudaMemset(d_dinp, 0, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMemset(d_dqkvr, 0, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMemset(d_dpreatt, 0, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMemset(d_datt, 0, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMemset(d_dvaccum, 0, B * T * C * sizeof(float)));\n\n    // call backward() on the GPU\n    attention_backward(kernel_num, d_dinp, d_dqkvr, d_dpreatt, d_datt, d_dvaccum,\n                       d_dout, d_inp, 
d_qkvr, d_preatt, d_att, d_vaccum,\n                       B, T, C, NH, block_size);\n\n    // check that the gradients match between the CPU and GPU versions\n    // note that we will only check the correctness at [att, preatt, inp]\n    // the gradients at qkvr and vaccum will remain unchecked, but are\n    // assumed to be correct if the other gradients are correct\n    printf(\"Checking the backward pass CPU <-> GPU...\\n\");\n    printf(\"[datt]\\n\");    validate_result(d_datt, datt, \"datt\", B * NH * T * T, 5e-3f);\n    printf(\"[dpreatt]\\n\"); validate_result(d_dpreatt, dpreatt, \"dpreatt\", B * NH * T * T, 1e-3f);\n    printf(\"[dinp]\\n\");    validate_result(d_dinp, dinp, \"dinp\", B * T * 3 * C, 1e-3f);\n\n    // also let's manually step through the gradients here\n    float* h_dinp = (float*)malloc(B * T * 3 * C * sizeof(float));\n    cudaCheck(cudaMemcpy(h_dinp, d_dinp, B * T * 3 * C * sizeof(float), cudaMemcpyDeviceToHost));\n    int num_match = 0;\n    int num_no_match = 0;\n    int num_zero_grad = 0;\n    int HS = C / NH;\n    for (int i = 0; i < B * T * 3 * C; i++) {\n\n        // the dimensions of inp are (B, T, 3, NH, HS)\n        // where B = batch, T = time, 3 = qkv, NH = num heads, HS = head size\n        // unpack the individual b,t,qkvix,h,c indices\n        int ix = i;\n        int c = ix % HS;\n        ix /= HS;\n        int h = ix % NH;\n        ix /= NH;\n        int qkvix = ix % 3;\n        ix /= 3;\n        int t = ix % T;\n        ix /= T;\n        int b = ix;\n\n        float diff = fabs(dinp[i] - h_dinp[i]);\n\n        // attempt to index at random\n        if (b == 1 && t == 5 && c == 23 && h == 2) {\n            printf(\"ix %5d [b=%4d, t=%4d, qkv=%4d, nh=%4d, hs=%4d]: ref: %f gpu: %f\\n\", i, b, t, qkvix, h, c, dinp[i], h_dinp[i]);\n        }\n\n        if (diff > 1e-4f) {\n            num_no_match++;\n        } else {\n            num_match++;\n        }\n\n        if (dinp[i] == 0.0f) {\n            num_zero_grad++;\n        
}\n    }\n    printf(\"Number of matching gradients: %d (%.2f%% of total)\\n\", num_match, 100*(float)num_match / (B * T * 3 * C));\n    printf(\"Number of non-matching gradients: %d (%.2f%% of total)\\n\", num_no_match, 100*(float)num_no_match / (B * T * 3 * C));\n    printf(\"Number of gradients that are exactly zero: %d (%.2f%% of total)\\n\", num_zero_grad, 100*(float)num_zero_grad / (B * T * 3 * C));\n\n    // final verdict\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    // benchmark speed of the kernel\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        int repeat_times = 10;\n        float elapsed_time = benchmark_kernel(repeat_times, attention_backward,\n                                              kernel_num, d_dinp, d_dqkvr, d_dpreatt, d_datt, d_dvaccum,\n                                              d_dout, d_inp, d_qkvr, d_preatt, d_att, d_vaccum,\n                                              B, T, C, NH, block_size);\n\n        printf(\"block_size %4d | time %f ms\\n\", block_size, elapsed_time);\n    }\n\n    // free memory\n    free(inp);\n    free(qkvr);\n    free(preatt);\n    free(att);\n    free(vaccum);\n    free(out);\n    free(dout);\n    free(dinp);\n    free(dpreatt);\n    free(datt);\n    free(h_dinp);\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_qkvr));\n    cudaCheck(cudaFree(d_preatt));\n    cudaCheck(cudaFree(d_att));\n    cudaCheck(cudaFree(d_vaccum));\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_dinp));\n    cudaCheck(cudaFree(d_dqkvr));\n    cudaCheck(cudaFree(d_dpreatt));\n    cudaCheck(cudaFree(d_datt));\n    cudaCheck(cudaFree(d_dvaccum));\n    cudaCheck(cudaFree(d_dout));\n    cublasDestroy(cublas_handle);\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/attention_forward.cu",
    "content": "/*\nKernels for attention forward pass.\n\nIf you do not have CUDNN, you can remove ENABLE_CUDNN to run the other kernels\n\nSee the README for cuDNN install instructions\n\nCompile example with cuDNN:\nnvcc -I/PATH/TO/cudnn-frontend/include -DENABLE_CUDNN -O3 --use_fast_math --lcublas -lcublasLt -lcudnn attention_forward.cu -o attention_forward\n\nCompile example without cuDNN:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt attention_forward.cu -o attention_forward\n\nversion 1 is naive port from CPU code to kernel, parallelize over batch, time, heads only\n./attention_forward 1\n\nversion 2 is a naive implementation of flash attention, taken, adapted from\nhttps://github.com/tspeterkim/flash-attention-minimal\nand with help from\nhttps://github.com/leloykun/flash-hyperbolic-attention-minimal\nsadly, this flash attention version seems about 3X slower than the naive version\n./attention_forward 2\n\nversion 3 is a cuBLAS + softmax version, similar to the PyTorch implementation\ncuBLAS is used both to calculate the QK^T and the final weighted sum\nthe softmax is calculated using a custom, efficient kernel as well\nthis turns out to be ~20X faster than (1) nice\n./attention_forward 3\n\nversion 4 is a further optimized kernel that fuses the scale operation,\nuses a directly autoregressive softmax, and uses the online softmax algorithm.\n./attention_forward 4\n\nversion 5 is a FP16 version of kernel 4\n./attention_forward 5\n\nversion 6 is kernel 5 skipping (un)permute (unrealistic but useful comparison point)\n\nversion 10 is using cuDNN Flash Attention using FP16 or BF16, see:\nhttps://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md\n./attention_forward 10\n\nversion 11 is kernel 10 skipping FP16/FP32 conversions (full FP16/BF16 network)\n./attention_forward 11\n*/\n//#define ENABLE_CUDNN // can be enabled via nvcc \"-DENABLE_CUDNN\"\n#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <float.h>\n#include 
<cublas_v2.h>\n#include <cuda_runtime.h>\n#include <cuda_bf16.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CUDA & cuDNN setup\nstatic bool first_run_validation = true; // always run e.g. permute on 1st run\n\n#ifdef ENABLE_CUDNN\n#include <cudnn_frontend.h>\nnamespace fe = cudnn_frontend;\n#if CUBLAS_LOWP == CUDA_R_16BF\n#define CUDNN_16BIT fe::DataType_t::BFLOAT16\n#else\n#define CUDNN_16BIT fe::DataType_t::HALF\n#endif\n\nstatic cudnnHandle_t cudnn_handle;\nstatic size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)\nstatic void* cudnn_workspace = NULL;\n\n#define checkCudaErr(err) assert((int)err == 0);\n#define checkCudnnErr(err) assert((int)err == 0);\n#endif // ENABLE_CUDNN\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid attention_forward_cpu(float* out, float* preatt, float* att,\n                       const float* inp,\n                       int B, int T, int C, int NH) {\n    // input is (B, T, 3C) Q,K,V\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n    int C3 = C*3;\n    int hs = C / NH; // head size\n    float scale = 1.0 / sqrtf(hs);\n\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            for (int h = 0; h < NH; h++) {\n                const float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n                // pass 1: calculate query dot key and maxval\n                float maxval = -FLT_MAX;\n                for (int t2 = 0; t2 <= t; t2++) {\n                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n\n                    // (query_t) dot (key_t2)\n           
         float val = 0.0f;\n                    for (int i = 0; i < hs; i++) {\n                        val += query_t[i] * key_t2[i];\n                    }\n                    val *= scale;\n                    if (val > maxval) {\n                        maxval = val;\n                    }\n\n                    preatt_bth[t2] = val;\n                }\n                // pad with -INFINITY outside of autoregressive region for debugging comparisons\n                for (int t2 = t+1; t2 < T; t2++) {\n                    preatt_bth[t2] = -INFINITY;\n                }\n\n                // pass 2: calculate the exp and keep track of sum\n                float expsum = 0.0f;\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float expv = expf(preatt_bth[t2] - maxval);\n                    expsum += expv;\n                    att_bth[t2] = expv;\n                }\n                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;\n\n                // pass 3: normalize to get the softmax\n                for (int t2 = 0; t2 < T; t2++) {\n                    if (t2 <= t) {\n                        att_bth[t2] *= expsum_inv;\n                    } else {\n                        // causal attention mask. 
not strictly necessary to set to zero here\n                        // only doing this explicitly for debugging and checking to PyTorch\n                        att_bth[t2] = 0.0f;\n                    }\n                }\n\n                // pass 4: accumulate weighted values into the output of attention\n                float* out_bth = out + b * T * C + t * C + h * hs;\n                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }\n                for (int t2 = 0; t2 <= t; t2++) {\n                    const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value\n                    float att_btht2 = att_bth[t2];\n                    for (int i = 0; i < hs; i++) {\n                        out_bth[i] += att_btht2 * value_t2[i];\n                    }\n                }\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n__global__ void attention_query_key_kernel1(float* preatt, const float* inp,\n                                           int B, int T, int C, int NH) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int total_threads = B * NH * T * T;\n\n    if (idx < total_threads) {\n        int t2 = idx % T;\n        int t = (idx / T) % T;\n        if (t2 > t) {\n            // autoregressive mask\n            preatt[idx] = -INFINITY;\n            return;\n        }\n        int h = (idx / (T * T)) % NH;\n        int b = idx / (NH * T * T);\n\n        int C3 = C*3;\n        int hs = C / NH; // head size\n        const float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n\n        // (query_t) dot (key_t2)\n        float val = 0.0f;\n        for (int i = 0; i < hs; i++) {\n            val += query_t[i] * key_t2[i];\n        }\n        val *= 1.0 / sqrtf(hs);\n\n        preatt[idx] = val;\n    }\n}\n\n__global__ void 
attention_softmax_kernel1(float* att, const float* preatt,\n                                         int B, int T, int NH) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int total_threads = B * T * NH;\n\n    if (idx < total_threads) {\n        int h = idx % NH;\n        int t = (idx / NH) % T;\n        int b = idx / (NH * T);\n\n        const float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;\n        float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n        // find maxval\n        float maxval = -FLT_MAX;\n        for (int t2 = 0; t2 <= t; t2++) {\n            if (preatt_bth[t2] > maxval) {\n                maxval = preatt_bth[t2];\n            }\n        }\n\n        // calculate the exp and keep track of sum\n        float expsum = 0.0f;\n        for (int t2 = 0; t2 <= t; t2++) {\n            float expv = expf(preatt_bth[t2] - maxval);\n            expsum += expv;\n            att_bth[t2] = expv;\n        }\n        float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;\n\n        // normalize to get the softmax\n        for (int t2 = 0; t2 < T; t2++) {\n            if (t2 <= t) {\n                att_bth[t2] *= expsum_inv;\n            } else {\n                // causal attention mask. not strictly necessary to set to zero here\n                // only doing this explicitly for debugging and checking to PyTorch\n                att_bth[t2] = 0.0f;\n            }\n        }\n    }\n}\n\n// warp-level reduction for finding the maximum value\n__device__ float warpReduceMax(float val) {\n    for (int offset = 16; offset > 0; offset /= 2) {\n        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));\n    }\n    return val;\n}\n\n__global__ void softmax_forward_kernel4(float* out, const float* inp, int N, int C) {\n    // out is (N, C) just like inp. 
Each row of inp will get softmaxed.\n    // same as kernel3, but can handle any block size (multiple of 32)\n    // each row of C elements is handled by block_size threads\n    // furthermore, each block_size threads get executed in warps of 32 threads\n\n    // special reduction operations warpReduceMax/warpReduceSum are used for intra-warp reductions\n    // shared memory is used for inter-warp reduction\n    extern __shared__ float shared[];\n    int idx = blockIdx.x;\n    int tid = threadIdx.x;\n    int warpId = threadIdx.x / 32; // warp index within a block\n    int laneId = threadIdx.x % 32; // thread index within a warp\n\n    // the number of warps per block. recall that blockDim.x is block_size\n    int warpsPerBlock = blockDim.x / 32;\n\n    // shared[] must be allocated to have 2 * warpsPerBlock elements\n    // first half for max values, the second half for sum values\n    float* maxvals = shared;\n    float* sumvals = &shared[warpsPerBlock];\n\n    // one row of inp, i.e. inp[idx, :] of shape (C,)\n    const float* x = inp + idx * C;\n\n    // first, thread coarsening by directly accessing global memory in series\n    float maxval = -INFINITY;\n    for (int i = tid; i < C; i += blockDim.x) {\n        maxval = fmaxf(maxval, x[i]);\n    }\n    // now within-warp reductions for maxval\n    maxval = warpReduceMax(maxval);\n\n    // the 0th thread of each warp writes the maxval of that warp to shared memory\n    if (laneId == 0) maxvals[warpId] = maxval;\n    __syncthreads();\n\n    // now the 0th thread reduces the maxvals in shared memory, i.e. 
across warps\n    if (tid == 0) {\n        float val = maxvals[tid];\n        for (int i = 1; i < warpsPerBlock; i++) {\n            val = fmaxf(val, maxvals[i]);\n        }\n        // store the final max in the first position\n        maxvals[0] = val;\n    }\n    __syncthreads();\n    // broadcast the max to all threads\n    float offset = maxvals[0];\n\n    // compute expf and write the result to global memory\n    for (int i = tid; i < C; i += blockDim.x) {\n        // subtract max for numerical stability\n        out[idx * C + i] = expf(x[i] - offset);\n    }\n\n    // okay now we calculated exp(x - max(x))\n    // step 2: sum all the values and divide by the sum\n\n    // thread coarsening for sum\n    x = out + idx * C;\n    float sumval = 0.0f;\n    for (int i = tid; i < C; i += blockDim.x) {\n        sumval += x[i];\n    }\n    // within-warp reduction for sumval\n    sumval = warpReduceSum(sumval);\n\n    // write sumval to shared memory\n    if (laneId == 0) sumvals[warpId] = sumval;\n    __syncthreads();\n\n    // inter-thread reduction of sum\n    if (tid == 0) {\n        float val = sumvals[tid];\n        for (int i = 1; i < warpsPerBlock; ++i) {\n            val += sumvals[i];\n        }\n        sumvals[0] = val;\n    }\n    __syncthreads();\n    // broadcast the sum to all threads\n    float sum = sumvals[0];\n\n    // divide the whole row by the sum\n    for (int i = tid; i < C; i += blockDim.x) {\n        out[idx * C + i] = x[i] / sum;\n    }\n}\n\n\n__device__ float& vec_at(float4& vec, int index) {\n    return reinterpret_cast<float*>(&vec)[index];\n}\n\n__device__ float vec_at(const float4& vec, int index) {\n    return reinterpret_cast<const float*>(&vec)[index];\n}\n\n__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {\n    // inp, out shape: (N, T, T), where N = B * NH\n    // fuses the multiplication by scale inside attention\n    // directly autoregressive, so we only compute the 
lower triangular part\n    // uses the online softmax algorithm\n    assert(T % 4  == 0);\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= N * T) {\n        return;\n    }\n    int own_pos = idx % T;\n    int pos_by_4 = own_pos / 4;\n\n    // one row of inp, i.e. inp[idx, :] of shape (T,)\n    const float* x = inp + idx * T;\n\n    // not INF, so we don't get NaNs accidentally when subtracting two values.\n    float maxval = -FLT_MAX;\n    float sumval = 0.0f;\n\n    const float4* x_vec = reinterpret_cast<const float4*>(x);\n    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {\n        float4 v = x_vec[i];\n        float old_maxval = maxval;\n        for(int k = 0; k < 4; ++k) {\n            maxval = fmaxf(maxval, vec_at(v, k));\n        }\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        for(int k = 0; k < 4; ++k) {\n            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));\n        }\n    }\n\n    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {\n        float old_maxval = maxval;\n        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));\n    }\n\n    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});\n    sumval *= expf(inv_temperature * (maxval - global_maxval));\n\n    float sum = cg::reduce(warp, sumval, cg::plus<float>{});\n    float norm = 1.f / sum;\n\n    // divide the whole row by the sum\n    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {\n        // recalculation is faster than doing the round-trip through memory.\n        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));\n 
       __stcs(out + idx * T + i, ev * norm);\n    }\n}\n\n\n__global__ void attention_value_kernel1(float* out, const float* att, const float* inp,\n                                       int B, int T, int C, int NH) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int total_threads = B * T * NH;\n\n    if (idx < total_threads) {\n        int h = idx % NH;\n        int t = (idx / NH) % T;\n        int b = idx / (NH * T);\n\n        int C3 = C*3;\n        int hs = C / NH; // head size\n\n        float* out_bth = out + b * T * C + t * C + h * hs;\n        const float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n        for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }\n        for (int t2 = 0; t2 <= t; t2++) {\n           const  float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value\n            float att_btht2 = att_bth[t2];\n            for (int i = 0; i < hs; i++) {\n                out_bth[i] += att_btht2 * value_t2[i];\n            }\n        }\n    }\n}\n\n__global__\nvoid attention_forward_kernel2(\n    const float* Q,\n    const float* K,\n    const float* V,\n    const int N,\n    const int d,\n    const int Tc,\n    const int Tr,\n    const int Bc,\n    const int Br,\n    const float softmax_scale,\n    float* l,\n    float* m,\n    float* O\n) {\n    int tx = threadIdx.x;\n    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index\n\n    // Offset into Q,K,V,O,l,m - different for each batch and head\n    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh\n    int lm_offset = (bx * gridDim.y * N) + (by * N);  // offset for l and m\n\n    // Define SRAM for Q,K,V,S\n    extern __shared__ float sram[];\n    int tile_size = Bc * d;  // size of Qi, Kj, Vj\n    float* Qi = sram;\n    float* Kj = &sram[tile_size];\n    float* Vj = &sram[tile_size * 2];\n    float* S = &sram[tile_size * 3];\n\n    for (int j = 0; j < Tc; j++) {\n\n        // Load Kj, Vj to SRAM\n        for (int x 
= 0; x < d; x++) {\n            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];\n            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];\n        }\n        __syncthreads();  // such that the inner loop can use the correct Kj, Vj\n\n        for (int i = 0; i < Tr; i++)  {\n            // if past the end of the sequence, break\n            if (i * Br + tx >= N) {\n                break;\n            }\n\n            // Load Qi to SRAM, l and m to registers\n            for (int x = 0; x < d; x++) {\n                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];\n            }\n            float row_m_prev = m[lm_offset + (Br * i) + tx];\n            float row_l_prev = l[lm_offset + (Br * i) + tx];\n\n            // S = QK^T, row_m = rowmax(S)\n            // S[tx][y] = Sum_{x = 0}^{d-1} {Qi[tx][x] * Kj[y][x]}\n            // row_m = Max_{y = 0}^{Bc-1} S[tx][y]\n            // with causal masking\n            float row_m = -INFINITY;\n            for (int y = 0; y < Bc; y++) {\n                if (j * Bc + y >= N) {\n                    break;\n                }\n                float sum = 0;\n                for (int x = 0; x < d; x++) {\n                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];\n                }\n                sum *= softmax_scale;\n                if (i * Br + tx < j * Bc + y)\n                    sum = -INFINITY;\n                S[(Bc * tx) + y] = sum;\n\n                if (sum > row_m)\n                    row_m = sum;\n            }\n\n            // implement softmax with causal masking\n            // P = exp(S - row_m), row_l = rowsum(P)\n            // P[tx][y] = exp(S[tx][y] - row_m)\n            float row_l = 0;\n            for (int y = 0; y < Bc; y++) {\n                if (j * Bc + y >= N) {\n                    break;\n                }\n                if (i * Br + tx < j * Bc + y)\n                    S[(Bc * tx) + y] = 0;\n                else\n                
    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);\n                row_l += S[(Bc * tx) + y];\n            }\n\n            // Compute new m and l\n            float row_m_new = max(row_m_prev, row_m);\n            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);\n\n            // Write O, l, m to HBM\n            for (int x = 0; x < d; x++) {\n                float pv = 0;  // Pij * Vj\n                for (int y = 0; y < Bc; y++) {\n                    if (j * Bc + y >= N) {\n                        break;\n                    }\n                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];\n                }\n                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \\\n                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \\\n                    + (__expf(row_m - row_m_new) * pv));\n            }\n            m[lm_offset + (Br * i) + tx] = row_m_new;\n            l[lm_offset + (Br * i) + tx] = row_l_new;\n        }\n        __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop\n    }\n}\n\n__global__ void permute_kernel(float* q, float* k, float* v,\n                               const float* inp,\n                               int B, int N, int NH, int d) {\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)\n    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int inp_idx = \\\n            (b * N * 3 * NH * d)\n            +   (n * 3 * NH * d)\n            +       (0 * NH * d)\n            +   
       (nh_ * d)\n            +                d_;\n\n        q[idx] = inp[inp_idx];\n        k[idx] = inp[inp_idx + NH * d];\n        v[idx] = inp[inp_idx + 2 * (NH * d)];\n    }\n}\n\n__global__ void unpermute_kernel(const float* inp, float *out, int B, int N, int NH, int d) {\n   // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        out[other_idx] = inp[idx];\n    }\n}\n\n__global__ void scale_kernel(float* inp, float scale, int B, int NH, int T) {\n    // scales the pre-softmax attention scores by scale\n    // and sets the autoregressive locations to -INFINITY\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < B * NH * T * T) {\n        int rest = idx % (NH * T * T);\n        rest = rest % (T * T);\n        int t2 = rest / T;\n        int t = rest % T;\n        if (t > t2) {\n            inp[idx] = -INFINITY;\n        } else {\n            inp[idx] *= scale;\n        }\n    }\n}\n\n// direct translation of the CPU kernel. 
Each warp handles ont (b, h, t) combination.\n// The important changes compared to the CPU version:\n//  - each inner loop is handled by a warp\n//  - don't write non-autoregressive parts\n//  - reordered the last loops so that we can do all writing in the outer loop.\n__global__ void attention_forward_fused1(float* out, float* preatt, float* att,\n                                         const float* inp,\n                                         int B, int T, int C, int NH) {\n    // input is (B, T, 3C) Q,K,V\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n    int C3 = C*3;\n    int hs = C / NH; // head size\n    float scale = 1.0 / sqrtf(hs);\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int t = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    int h = blockIdx.y;\n    int b = blockIdx.z;\n\n    if(t >= T) return;\n\n    const float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n    float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;\n    float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n    // pass 1: calculate query dot key and maxval\n    float maxval = -INFINITY;\n    for (int t2 = 0; t2 <= t; t2++) {\n        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n\n        // (query_t) dot (key_t2)\n        float val = 0.0f;\n        for (int i = warp.thread_rank(); i < hs; i += warp.size()) {\n            val += query_t[i] * key_t2[i];\n        }\n        val = cg::reduce(warp, val, cg::plus<float>{});\n        val *= scale;\n        maxval = max(maxval, val);\n        if(warp.thread_rank() == 0) {\n            preatt_bth[t2] = val;\n        }\n    }\n\n    // pass 2: calculate the exp and keep track of sum\n    float expsum = 0.0f;\n    for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {\n        float expv = expf(preatt_bth[t2] - maxval);\n        
expsum += expv;\n    }\n\n    expsum = cg::reduce(warp, expsum, cg::plus<float>{});\n\n    float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;\n\n    // pass 3: normalize to get the softmax is combined with the next loop to reduce memory round-trips\n    for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {\n        att_bth[t2] = expf(preatt_bth[t2] - maxval) * expsum_inv;\n    }\n\n    // pass 4: accumulate weighted values into the output of attention\n    float* out_bth = out + b * T * C + t * C + h * hs;\n    for (int i = warp.thread_rank(); i < hs; i += warp.size()) {\n        float o = 0.f;\n        for (int t2 = 0; t2 <= t; t2++) {\n            const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C * 2; // +C*2 because it's value\n            float att_btht2 = att_bth[t2];\n            o += att_btht2 * value_t2[i];\n        }\n        out_bth[i] = o;\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid attention_forward1(float* out, float* preatt, float* att,\n                       const float* inp,\n                       int B, int T, int C, int NH,\n                       const int block_size) {\n    // attention calculation\n    int total_threads = B * NH * T * T;\n    int num_blocks = ceil_div(total_threads, block_size);\n    attention_query_key_kernel1<<<num_blocks, block_size>>>(preatt, inp, B, T, C, NH);\n    // softmax and value accumulation\n    total_threads = B * T * NH;\n    num_blocks = ceil_div(total_threads, block_size);\n    attention_softmax_kernel1<<<num_blocks, block_size>>>(att, preatt, B, T, NH);\n    attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);\n}\n\n\nvoid attention_forward2(float* out,\n                       const float* inp,\n                       int B, int T, int C, int NH,\n                       const int block_size) {\n    // TODO there should be no mallocs inside any of these functions!\n    // not 
fixing this because we don't intend to use attention_forward2,\n    // it seems to be way too slow as is\n\n    // these are hardcoded to 32 for now\n    const int Bc = 32;\n    const int Br = 32;\n    // renaming these to be consistent with the kernel\n    // const int B = B;\n    const int nh = NH;\n    const int N = T;\n    const int d = C / NH;\n    // more\n    const int Tc = ceil((float) N / Bc);\n    const int Tr = ceil((float) N / Br);\n    const float softmax_scale = 1.0 / sqrt(d);\n    // create some temporary memory\n    float* l;\n    float* m;\n    cudaCheck(cudaMalloc(&l, B * nh * N * sizeof(float)));\n    cudaCheck(cudaMalloc(&m, B * nh * N * sizeof(float)));\n    cudaCheck(cudaMemset(l, 0, B * nh * N * sizeof(float)));\n    cudaCheck(cudaMemset(m, -10000.0f, B * nh * N * sizeof(float)));\n\n    // calculate SRAM size needed per block, ensure we have enough shared memory\n    int col_tile_size = Bc * d;  // size of Kj, Vj\n    int row_tile_size = Br * d;  // size of Qi\n    const int sram_size =\n        (2 * col_tile_size * sizeof(float))  // SRAM size for Kj, Vj\n        + (row_tile_size * sizeof(float))  // SRAM size for Qi\n        + (Bc * Br * sizeof(float));  // SRAM size for S\n    int max_sram_size;\n    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);\n    if (sram_size > max_sram_size) {\n        printf(\"Max shared memory: %d, requested shared memory: %d \\n\", max_sram_size, sram_size);\n        printf(\"SRAM size exceeds maximum shared memory per block\\n\");\n        printf(\"Try decreasing col_tile_size or row_tile_size further\\n\");\n        exit(1);\n    }\n\n    // grid and block dims\n    dim3 grid_dim(B, nh);  // batch_size x num_heads\n    dim3 block_dim(Br);  // Br threads per block\n\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, nh, N, d)\n    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, nh, d)\n    // so we have to permute the tensor using a kernel 
with block_size\n    float *q, *k, *v;\n    cudaCheck(cudaMalloc(&q, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&k, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&v, B * T * C * sizeof(float)));\n    int total_threads = B * N * nh * d;\n    int num_blocks = ceil_div(total_threads, block_size);\n    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, N, nh, d);\n\n    // now actually call the flash attention kernel\n    attention_forward_kernel2<<<grid_dim, block_dim, sram_size>>>(\n        q, k, v,\n        N, d, Tc, Tr, Bc, Br, softmax_scale,\n        l, m, out\n    );\n\n    // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n    unpermute_kernel<<<num_blocks, block_size>>>(out, q, B, N, nh, d);\n    cudaCheck(cudaMemcpy(out, q, B * T * C * sizeof(float), cudaMemcpyDeviceToDevice));\n\n    // free memory\n    cudaCheck(cudaFree(l));\n    cudaCheck(cudaFree(m));\n    cudaCheck(cudaFree(q));\n    cudaCheck(cudaFree(k));\n    cudaCheck(cudaFree(v));\n}\n\nvoid attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, float* att,\n                       const float* inp,\n                       int B, int T, int C, int NH,\n                       const int block_size) {\n    // inp is (B, T, 3C) QKV\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n    int HS = C / NH; // head size\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    float *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    int total_threads = B * NH * T * HS;\n    int num_blocks = ceil_div(total_threads, block_size);\n    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);\n\n    // batched matrix multiply with cuBLAS\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                            CUBLAS_OP_T, CUBLAS_OP_N,\n                  
          T, T, HS,\n                            &alpha,\n                            k, HS, T * HS,\n                            q, HS, T * HS,\n                            &beta,\n                            preatt, T, T * T,\n                            B * NH));\n\n    // multiply all elements of preatt elementwise by scale\n    float scale = 1.0f / sqrtf(HS);\n    total_threads = B * NH * T * T;\n    num_blocks = ceil_div(total_threads, block_size);\n    scale_kernel<<<num_blocks, block_size>>>(preatt, scale, B, NH, T);\n\n    // softmax. preatt is (B, NH, T, T) but we view it as (B * NH * T, T) and use the softmax kernel\n    int softmax_block_size = 256;\n    int grid_size = B * NH * T;\n    size_t shared_mem_size = 2 * softmax_block_size / 32 * sizeof(float);\n    softmax_forward_kernel4<<<grid_size, softmax_block_size, shared_mem_size>>>(att, preatt, B * NH * T, T);\n\n    // new approach: first cuBLAS another batched matmul\n    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                            CUBLAS_OP_N, CUBLAS_OP_N,\n                            HS, T, T,\n                            &alpha,\n                            v, HS, T * HS,\n                            att, T, T * T,\n                            &beta,\n                            vaccum, HS, T * HS,\n                            B * NH));\n\n    // now unpermute\n    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n    num_blocks = ceil_div(B * T * C, block_size);\n    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);\n}\n\nvoid attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, float* att,\n                        const float* inp,\n                        int B, int T, int C, int NH,\n                        const int block_size) {\n    // inp is (B, T, 3C) QKV\n    // preatt, att are (B, NH, T, T)\n    // output 
is (B, T, C)\n    int HS = C / NH; // head size\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    float *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    int total_threads = B * NH * T * HS;\n    int num_blocks = ceil_div(total_threads, block_size);\n    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);\n\n    // batched matrix multiply with cuBLAS\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                                     CUBLAS_OP_T, CUBLAS_OP_N,\n                                     T, T, HS,\n                                     &alpha,\n                                     k, HS, T * HS,\n                                     q, HS, T * HS,\n                                     &beta,\n                                     preatt, T, T * T,\n                                     B * NH));\n\n    // multiply all elements of preatt elementwise by scale\n    float scale = 1.0 / sqrtf(HS);\n    int softmax_block_size = 256;\n    int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);\n    softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);\n\n    // new approach: first cuBLAS another batched matmul\n    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                                     CUBLAS_OP_N, CUBLAS_OP_N,\n                                     HS, T, T,\n                                     &alpha,\n                                     v, HS, T * HS,\n                                     att, T, T * T,\n                                     &beta,\n                                     vaccum, HS, T * HS,\n                                     B * NH));\n\n    // now unpermute\n    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble 
all head outputs side by side\n    num_blocks = ceil_div(B * T * C, block_size);\n    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);\n}\n\n\n__global__ void softmax_forward_kernel5_lowp(floatX* out, float inv_temperature,\n                                             const floatX* inp, int N, int T) {\n    // inp, out shape: (N, T, T), where N = B * NH\n    // fuses the multiplication by scale inside attention\n    // directly autoregressive, so we only compute the lower triangular part\n    // uses the online softmax algorithm\n    assert(T % 4  == 0);\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= N * T) {\n        return;\n    }\n    int own_pos = idx % T;\n    int pos_by_4 = own_pos / 4;\n\n    // one row of inp, i.e. inp[idx, :] of shape (T,)\n    const floatX* x = inp + idx * T;\n\n    // not INF, so we don't get NaNs accidentally when subtracting two values.\n    float maxval = -FLT_MAX;\n    float sumval = 0.0f;\n\n    // Same thing but without float4, one at a time\n    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {\n        float old_maxval = maxval;\n        for(int k = 0; k < 4; ++k) {\n            maxval = fmaxf(maxval, (float)x[4*i + k]);\n        }\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        for(int k = 0; k < 4; ++k) {\n            sumval += expf(inv_temperature * ((float)x[4*i + k] - maxval));\n        }\n    }\n\n    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {\n        float old_maxval = maxval;\n        maxval = fmaxf(maxval, (float)x[4*pos_by_4 + warp.thread_rank()]);\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        sumval += expf(inv_temperature * ((float)x[4*pos_by_4 + warp.thread_rank()] - maxval));\n    }\n\n    float 
global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});\n    sumval *= expf(inv_temperature * (maxval - global_maxval));\n\n    float sum = cg::reduce(warp, sumval, cg::plus<float>{});\n    float norm = 1.f / sum;\n\n    // divide the whole row by the sum\n    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {\n        // recalculation is faster than doing the round-trip through memory.\n        float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval));\n        __stcs(out + idx * T + i, (floatX)(ev * norm));\n    }\n}\n\n__global__ void permute_kernel_lowp(floatX* q, floatX* k, floatX* v,\n                                    const float* inp,\n                                    int B, int N, int NH, int d) {\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)\n    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int inp_idx = \\\n            (b * N * 3 * NH * d)\n            +   (n * 3 * NH * d)\n            +       (0 * NH * d)\n            +          (nh_ * d)\n            +                d_;\n\n        q[idx] = (floatX)inp[inp_idx];\n        k[idx] = (floatX)inp[inp_idx + NH * d];\n        v[idx] = (floatX)inp[inp_idx + 2 * (NH * d)];\n    }\n}\n\n__global__ void unpermute_kernel_lowp(const floatX* inp, float *out, int B, int N, int NH, int d) {\n   // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * 
N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        out[other_idx] = (float)inp[idx];\n    }\n}\n\nvoid attention_forward5(float* out, floatX* vaccum, floatX* qkvr, floatX* preatt, floatX* att,\n                        const float* inp,\n                        int B, int T, int C, int NH,\n                        const int block_size, bool skip_permute=false) {\n    // FP16 version of kernel 4 (with permute/unpermute doing FP32<->FP16)\n    // That permute can be skipped on perf runs to analyse its performance impact\n    // inp is (B, T, 3C) QKV\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    int HS = C / NH; // head size\n    floatX *q = qkvr + 0 * B * T * C;\n    floatX *k = qkvr + 1 * B * T * C;\n    floatX* v = qkvr + 2 * B * T * C;\n\n    int total_threads = B * NH * T * HS;\n    int num_blocks = ceil_div(total_threads, block_size);\n    if (!skip_permute || first_run_validation) {\n        permute_kernel_lowp<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);\n    }\n\n    // IMPORTANT: alpha/beta are FP32 for CUBLAS_COMPUTE_32F even if FP16 inputs/outputs\n    // But need FP16 scale for CUBLAS_COMPUTE_16F (no errors otherwise, just garbage results *sigh*)\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n    const floatX alpha_lowp = (floatX)alpha;\n    const floatX beta_lowp = (floatX)beta;\n    void* alpha_ptr = CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F ? (void*)&alpha_lowp : (void*)&alpha;\n    void* beta_ptr = CUBLAS_LOWP_COMPUTE == CUBLAS_COMPUTE_16F ? 
(void*)&beta_lowp : (void*)&beta;\n\n    // batched matrix multiply with cuBLAS\n    cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,\n                                     CUBLAS_OP_T, CUBLAS_OP_N,\n                                     T, T, HS,\n                                     alpha_ptr,\n                                     k, CUBLAS_LOWP, HS, T * HS,\n                                     q, CUBLAS_LOWP, HS, T * HS,\n                                     beta_ptr,\n                                     preatt, CUBLAS_LOWP, T, T * T,\n                                     B * NH,\n                                     CUBLAS_LOWP_COMPUTE,\n                                     CUBLAS_GEMM_DEFAULT));\n\n    // multiply all elements of preatt elementwise by scale\n    float scale = 1.0f / sqrtf(HS);\n    int softmax_block_size = 256;\n    int grid_size = ceil_div(B * NH * T * 32, softmax_block_size);\n    softmax_forward_kernel5_lowp<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);\n\n    // new approach: first cuBLAS another batched matmul\n    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)\n    cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,\n                                     CUBLAS_OP_N, CUBLAS_OP_N,\n                                     HS, T, T,\n                                     alpha_ptr,\n                                     v, CUBLAS_LOWP, HS, T * HS,\n                                     att, CUBLAS_LOWP, T, T * T,\n                                     beta_ptr,\n                                     vaccum, CUBLAS_LOWP, HS, T * HS,\n                                     B * NH,\n                                     CUBLAS_LOWP_COMPUTE,\n                                     CUBLAS_GEMM_DEFAULT));\n\n    // now unpermute\n    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n    num_blocks = ceil_div(B * T * C, block_size);\n    if(!skip_permute || 
first_run_validation) {\n        unpermute_kernel_lowp<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);\n    }\n}\n\n#ifdef ENABLE_CUDNN\nusing graph_tensors_fwd = std::tuple<std::shared_ptr<fe::graph::Graph>,\n                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Q,\n                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // K,\n                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // V,\n                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // Attn_scale,\n                                     std::shared_ptr<fe::graph::Tensor_attributes>,  // O\n                                     std::shared_ptr<fe::graph::Tensor_attributes>>; // Stats\n\n// Need a cache because graph->build_operation_graph() is slow but everything else seems fast\nusing cache_type_fwd = std::unordered_map<std::size_t, graph_tensors_fwd>;\n\n// Loosely based on cuDNN frontend samples functions and massively simplified\ntemplate <typename... Args>\nauto lookup_cache_or_build_graph_fwd(Args... 
args) {\n    static cache_type_fwd user_maintained_cache_fwd;\n    auto [B, H, T, HS, is_inference_only] = std::make_tuple(args...);\n\n    auto graph = std::make_shared<fe::graph::Graph>();\n    graph->set_io_data_type(CUDNN_16BIT)\n          .set_intermediate_data_type(fe::DataType_t::FLOAT)\n          .set_compute_data_type(fe::DataType_t::FLOAT);\n\n    // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute\n    auto Q = graph->tensor(fe::graph::Tensor_attributes()\n                               .set_name(\"Q\")\n                               .set_dim({B, H, T, HS})\n                               .set_stride({3 * H * HS * T,  HS, 3 * H * HS, 1}));\n    auto K = graph->tensor(fe::graph::Tensor_attributes()\n                               .set_name(\"K\")\n                               .set_dim({B, H, T, HS})\n                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));\n    auto V = graph->tensor(fe::graph::Tensor_attributes()\n                               .set_name(\"V\")\n                               .set_dim({B, H, T, HS})\n                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));\n    auto attn_scale = graph->tensor(fe::graph::Tensor_attributes()\n                                .set_name(\"attn_scale\")\n                                .set_dim({1, 1, 1, 1})\n                                .set_stride({1, 1, 1, 1})\n                                .set_is_pass_by_value(true)\n                                .set_data_type(fe::DataType_t::FLOAT));\n\n    auto sdpa_options = fe::graph::SDPA_attributes().set_name(\"flash_attention\");\n    sdpa_options.set_is_inference(is_inference_only);\n    sdpa_options.set_attn_scale(attn_scale);\n    sdpa_options.set_causal_mask(true);\n\n    // Create the graph operation and get the output tensors back\n    auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);\n\n    // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is 
(B, NH, T) FP32\n    O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1});\n\n    assert(stats == nullptr || is_inference_only == false);\n    if (is_inference_only == false) {\n        stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)\n                               .set_dim({B, H, T, 1})\n                               .set_stride({H * T, T, 1, 1});\n    }\n\n    assert(graph->validate().is_good());\n    auto key = graph->key();\n    auto it = user_maintained_cache_fwd.find(key);\n    if (it != user_maintained_cache_fwd.end()) {\n        return it->second;\n    }\n\n    // Build the operation graph and execution part (this is the VERY SLOW PART)\n    assert(graph->build_operation_graph(cudnn_handle).is_good());\n    auto plans = graph->create_execution_plans({fe::HeurMode_t::A});\n    assert(graph->check_support(cudnn_handle).is_good());\n    assert(graph->build_plans(cudnn_handle).is_good());\n\n    auto tuple = std::make_tuple(graph, Q, K, V, attn_scale, O, stats);\n    user_maintained_cache_fwd.insert({key, tuple});\n    return tuple;\n}\n\n// Used on first run only so we can validate against the CPU results\n__global__ void fp32_to_lowp_kernel(floatX* out, const float* inp) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    out[idx] = (floatX)inp[idx];\n}\n\n__global__ void lowp_to_fp32_kernel(const floatX* inp, float *out) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    out[idx] = (float)inp[idx];\n}\n\nvoid attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)\n                             float* stats, // output for backward pass: (B, NH, T)\n                             floatX* inp,  // input: (B, T, 3, NH, HS) QKV\n                             float* in_fp32,  // fp32 input\n                             float* out_fp32, // fp32 output for validation\n                             int B, int T, int C, int NH) {\n    static bool first_run_validation = true;\n    int HS = C / NH; // 
number of features per head\n    bool is_inference_only = (stats == nullptr);\n\n    // Convert from FP32 to FP16/BF16 on 1st run to get correct results\n    const int block_size = 64; // smallest full occupancy block size on modern GPUs\n    if (first_run_validation) {\n        int total_threads = B * T * C * 3;\n        assert(total_threads % block_size == 0);\n        int num_blocks = total_threads / block_size;\n        fp32_to_lowp_kernel<<<num_blocks, block_size>>>(inp, in_fp32);\n    }\n\n    // Get graph and tensors from cache (or generate it on first use)\n    auto [graph, Q, K, V, attn_scale, O, softmax_stats] =\n        lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);\n\n    // Prepare all the tensor pointers for executing the graph\n    void* devPtrQ = inp;\n    void* devPtrK = (inp + C);\n    void* devPtrV = (inp + 2 * C);\n    float attn_scale_cpu = 1.0 / sqrtf(HS);\n    void* devPtrO = out;\n\n    // Build variant pack\n    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {\n        {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {O, devPtrO}};\n\n    // Add the stats tensor unless we are only doing inference (only needed for backward pass)\n    if (is_inference_only == false) {\n        variant_pack[softmax_stats] = stats;\n    }\n\n    // Reallocate the workspace if the required size is greater than the current workspace\n    // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum\n    if (graph->get_workspace_size() > cudnn_workspace_size) {\n        if (cudnn_workspace_size > 0) {\n            cudaCheck(cudaFree(cudnn_workspace));\n        }\n        cudnn_workspace_size = graph->get_workspace_size();\n        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));\n    }\n\n    // Execute graph\n    assert(graph->execute(cudnn_handle, variant_pack, cudnn_workspace).is_good());\n    cudaCheck(cudaGetLastError());\n\n  
  // Optionally convert back from FP16/BF16 to FP32\n    if (first_run_validation) {\n        int total_threads = B * T * C;\n        assert(total_threads % block_size == 0);\n        int num_blocks = total_threads / block_size;\n        lowp_to_fp32_kernel<<<num_blocks, block_size>>>(out, out_fp32);\n    }\n    cudaCheck(cudaGetLastError());\n    first_run_validation = false;\n}\n\n#endif // ENABLE_CUDNN\n\n// kernel version dispatch\nvoid attention_forward(int kernel_num,\n                       float* out, float* stats, float* vaccum,\n                       float* qkvr, float* preatt, float* att,\n                       float* inp,\n                       int B, int T, int C, int NH,\n                       const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            attention_forward1(out, preatt, att, inp, B, T, C, NH, block_size);\n            break;\n        case 2:\n            attention_forward2(out, inp, B, T, C, NH, block_size);\n            break;\n        case 3:\n            attention_forward3(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size);\n            break;\n        case 4:\n            attention_forward4(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size);\n            break;\n        case 5:\n            attention_forward5(out, (floatX*)vaccum, (floatX*)qkvr,\n                               (floatX*)preatt, (floatX*)att,\n                               inp, B, T, C, NH, block_size, false);\n            break;\n        case 6: // skip permutes for perf passes (to analyse perf as if in/out were truly 16-bit)\n            attention_forward5(out, (floatX*)vaccum, (floatX*)qkvr,\n                               (floatX*)preatt, (floatX*)att,\n                               inp, B, T, C, NH, block_size, true);\n            break;\n        #ifdef ENABLE_CUDNN\n        case 10:\n            // note: validation only cares about out, which is out_fp32 of the function\n            // inp is hackily converted to 
FP16 into qkvr only on the first run\n            // similarly, vaccum is converted to FP32 into out only on the first run\n            attention_forward_cudnn((floatX*)vaccum, stats, (floatX*)qkvr, inp, out, B, T, C, NH);\n            break;\n        #endif\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n    int NH = 12;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n\n    // setup cuBLAS (and cuDNN if needed)\n    cublasCreate(&cublas_handle);\n    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;\n    printf(\"enable_tf32: %d\\n\", enable_tf32);\n    cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;\n    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));\n\n    #ifdef ENABLE_CUDNN\n    checkCudnnErr(cudnnCreate(&cudnn_handle));\n    #endif\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n    float* preatt = (float*)malloc(B * NH * T * T * sizeof(float));\n    float* att = (float*)malloc(B * NH * T * T * sizeof(float));\n    //float* inp = make_random_float(B * T * 3 * C, 10.0f);\n    float* inp = make_random_float(B * T * 3 * C);\n\n    // move to GPU\n    float* d_out;\n    float* d_stats; // for cuDNN\n    float* d_vaccum;\n    float* d_qkvr;\n    float* d_preatt;\n    float* d_att;\n    float* d_inp;\n    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_stats, B * NH * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_preatt, 
B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n    int block_sizes[] = {32, 64, 128, 256, 512};\n\n    // Lower accuracy requirements for FP16 (1e-4f also too much for TF32 on kernels 3 & 4)\n    float accuracy_threshold = (kernel_num <= 4) ? 1e-3f : 1e-2f;\n\n    // first check the correctness of the kernel\n    attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        attention_forward(kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size);\n        // all kernels should produce the correct output out\n        // todo - make accuracy threshold dynamic and depend on FP16 vs FP32?\n        validate_result(d_out, out, \"out\", B * T * C, accuracy_threshold);\n        // but as for preatt and att, things get a bit more complicated:\n        if (kernel_num != 2 && kernel_num < 5) {\n            // kernel 2 (knowingly) fails att/preatt because it uses a different algorithm\n            // that estimates the softmax online and never materializes preatt/att\n            validate_result(d_att, att, \"att\", B * NH * T * T, accuracy_threshold);\n        }\n        if (kernel_num != 2 && kernel_num < 4) {\n            // kernel 4 (knowingly) fails preatt because it fuses the scale normalization\n            // into the softmax, so preatt is off by 1.0f / sqrt(HS)\n            // but att and out (checked below) should match.\n            validate_result(d_preatt, 
preatt, \"preatt\", B * NH * T * T, accuracy_threshold);\n        }\n    }\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n    first_run_validation = false;\n\n    // benchmark speed of the kernel\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        int repeat_times = 100;\n\n        float elapsed_time = benchmark_kernel(repeat_times, attention_forward,\n                                              kernel_num, d_out, d_stats, d_vaccum, d_qkvr, d_preatt, d_att,\n                                              d_inp, B, T, C, NH, block_size);\n\n        printf(\"block_size %4d | time %f ms\\n\", block_size, elapsed_time);\n    }\n\n    // free memory\n    free(out);\n    free(preatt);\n    free(att);\n    free(inp);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_vaccum));\n    cudaCheck(cudaFree(d_qkvr));\n    cudaCheck(cudaFree(d_preatt));\n    cudaCheck(cudaFree(d_att));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_stats));\n    cublasDestroy(cublas_handle);\n\n    #ifdef ENABLE_CUDNN\n    cudnnDestroy(cudnn_handle);\n    if (cudnn_workspace_size > 0) {\n        cudaCheck(cudaFree(cudnn_workspace));\n    }\n    #endif\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/benchmark_on_modal.py",
    "content": "\"\"\"\nScript for running benchmarks on the Modal platform.\nThis is useful for folks who do not have access to expensive GPUs locally.\nExample usage for cuda kernels:\nGPU_MEM=80 modal run benchmark_on_modal.py \\\n    --compile-command \"nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas\" \\\n    --run-command \"./attention_forward 1\"\nOR if you want to use cuDNN etc.\n\n\nFor training the gpt2 model with cuDNN use:\nGPU_MEM=80 modal run dev/cuda/benchmark_on_modal.py \\\n    --compile-command \"make train_gpt2cu USE_CUDNN=1\" \\\n    --run-command \"./train_gpt2cu -i dev/data/tinyshakespeare/tiny_shakespeare_train.bin -j dev/data/tinyshakespeare/tiny_shakespeare_val.bin -v 250 -s 250 -g 144 -f shakespeare.log -b 4\"\n\n\nFor profiling using nsight system:\nGPU_MEM=80 modal run dev/cuda/benchmark_on_modal.py \\\n    --compile-command \"make train_gpt2cu USE_CUDNN=1\" \\\n    --run-command \"nsys profile --cuda-graph-trace=graph --python-backtrace=cuda --cuda-memory-usage=true \\\n    ./train_gpt2cu -i dev/data/tinyshakespeare/tiny_shakespeare_train.bin \\\n    -j dev/data/tinyshakespeare/tiny_shakespeare_val.bin -v 250 -s 250 -g 144 -f shakespeare.log -b 4\"\n\nFor more nsys profiling specifics and command options, take a look at: https://docs.nvidia.com/nsight-systems/2024.2/UserGuide/\n-> To profile the report using a GUI, download NVIDIA NSight System GUI version (this software can run on all OS, so you download it locally)\n\nNOTE: Currently there is a bug in the profiling using nsight system which produces an unrecognized GPU UUID error on the command line but it\ndoes not actually interfere with the model training and validation. 
The report (that you download) is still generated and can be viewed from Nsight Systems\n\"\"\"\nimport subprocess\nimport os\nimport sys\nimport datetime\n\nimport modal\nfrom modal import Image, Stub\nGPU_NAME_TO_MODAL_CLASS_MAP = {\n    \"H100\": modal.gpu.H100,\n    \"A100\": modal.gpu.A100,\n    \"A10G\": modal.gpu.A10G,\n}\nN_GPUS = int(os.environ.get(\"N_GPUS\", 1))\nGPU_MEM = int(os.environ.get(\"GPU_MEM\", 40))\nGPU_NAME = os.environ.get(\"GPU_NAME\", \"A100\")\nGPU_CONFIG = GPU_NAME_TO_MODAL_CLASS_MAP[GPU_NAME](count=N_GPUS, size=str(GPU_MEM) + 'GB')\n\nAPP_NAME = \"llm.c benchmark run\"\n\nimage = (\n    Image.from_registry(\"totallyvyom/cuda-env:latest-2\")\n    .pip_install(\"huggingface_hub==0.20.3\", \"hf-transfer==0.1.5\")\n    .env(\n        dict(\n            HUGGINGFACE_HUB_CACHE=\"/pretrained\",\n            HF_HUB_ENABLE_HF_TRANSFER=\"1\",\n            TQDM_DISABLE=\"true\",\n        )\n    )\n    .run_commands(\n    \"wget -q https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-Linux-x86_64.sh\",\n    \"bash cmake-3.28.1-Linux-x86_64.sh --skip-license --prefix=/usr/local\",\n    \"rm cmake-3.28.1-Linux-x86_64.sh\",\n    \"ln -s /usr/local/bin/cmake /usr/bin/cmake\",)\n    .run_commands(\n        \"apt-get install -y --allow-change-held-packages libcudnn8 libcudnn8-dev\",\n        \"apt-get install -y openmpi-bin openmpi-doc libopenmpi-dev kmod sudo\",\n        \"git clone https://github.com/NVIDIA/cudnn-frontend.git /root/cudnn-frontend\",\n        \"cd /root/cudnn-frontend && mkdir build && cd build && cmake .. 
&& make\"\n    )\n    .run_commands(\n        \"wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \\\n        mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \\\n        apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \\\n        add-apt-repository \\\"deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /\\\" && \\\n        apt-get update\"\n    ).run_commands(\n        \"apt-get install -y nsight-systems-2023.3.3\"\n    )\n)\n\nstub = modal.App(APP_NAME)\n\ndef execute_command(command: str):\n    command_args = command.split(\" \")\n    print(f\"{command_args = }\")\n    subprocess.run(command_args, stdout=sys.stdout, stderr=subprocess.STDOUT)\n\n@stub.function(\n    gpu=GPU_CONFIG,\n    image=image,\n    allow_concurrent_inputs=4,\n    container_idle_timeout=900,\n    mounts=[modal.Mount.from_local_dir(\"./\", remote_path=\"/root/\")],\n    # Instead of 'cuda-env' put your volume name that you create from 'modal volume create {volume-name}'\n    # This enables the profiling reports to be saved on the volume that you can download by using:\n    # 'modal volume get {volume-name} {/output_file_name}\n    # For example right now, when profiling using this command \"nsys profile --trace=cuda,nvtx --cuda-graph-trace=graph --python-backtrace=cuda --cuda-memory-usage=true\" you would get your report\n    # using in a directory in your volume, where the name contains the timestamp unique id.\n    # This script will generate a \"report1_{timestamp} folder in volume\"\n    # and you can download it with 'modal volume get {volume-name} report1_{timestamp}\n    volumes={\"/cuda-env\": modal.Volume.from_name(\"cuda-env\")},\n)\ndef run_benchmark(compile_command: str, run_command: str):\n    execute_command(\"pwd\")\n    execute_command(\"ls\")\n    execute_command(compile_command)\n    
execute_command(run_command)\n    # Use this section if you want to profile using nsight system and install the reports on your volume to be locally downloaded\n    timestamp = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n\n    execute_command(\"mkdir report1_\" + timestamp)\n    execute_command(\"mv /root/report1.nsys-rep /root/report1_\" + timestamp + \"/\")\n    execute_command(\"mv /root/report1.qdstrm /root/report1_\" + timestamp + \"/\")\n    execute_command(\"mv /root/report1_\" + timestamp + \"/\" + \" /cuda-env/\")\n\n    return None\n\n@stub.local_entrypoint()\ndef inference_main(compile_command: str, run_command: str):\n    results = run_benchmark.remote(compile_command, run_command)\n    return results"
  },
  {
    "path": "dev/cuda/classifier_fused.cu",
    "content": "/*  Kernels for fused forward/backward classifier part\nThis fuses softmax, crossentropy, and logit gradients into a single pass, so we don't have to write unnecessary\n(B, T, V) tensors. Such an operation is only possible if `dloss` can be known beforehand, which doesn't seem like\nmuch of a restriction: In pretraining, it is just a constant 1/batch_size tensor, for fine-tuning we might zero\nout the input prompt, but that is known in advance.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt classifier_fused.cu -o classifier_fused\n\n./classifier_fused 1\n./classifier_fused 2\n./classifier_fused 3\n./classifier_fused 4\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <float.h>\n#include <cuda_runtime.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include \"common.h\"\n\n// todo - this file does not properly support anything but FP32\n// kernel 5 can be run in fp16/bf16 to test performance, but the outputs will be wrong\n#if defined(ENABLE_BF16)\ntypedef __nv_bfloat16 floatX;\n#elif defined(ENABLE_FP16)\ntypedef half floatX;\n#else\ntypedef float floatX;\n#endif\ntypedef Packed128<floatX> x128;\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid softmax_forward_cpu(float* out, const float* inp, int N, int C) {\n    // inp is (N, C)\n    // out is (N, C), each row of inp will get softmaxed\n    for (int64_t i = 0; i < N; i++) {\n        const float* inp_row = inp + i * C;\n        float* out_row = out + i * C;\n\n        float maxval = -INFINITY;\n        for (int j = 0; j < C; j++) {\n            if (inp_row[j] > maxval) {\n                maxval = inp_row[j];\n            }\n        }\n        double sum = 0.0;\n        for (int j = 0; j < C; j++) {\n            out_row[j] = expf(inp_row[j] - maxval);\n            sum += out_row[j];\n        }\n        for (int j = 0; j < C; j++) {\n            out_row[j] /= sum;\n        }\n    
}\n}\n\n\nvoid crossentropy_forward_cpu(float* losses,\n                              const float* probs, const int* targets,\n                              int B, int T, int V) {\n    // output: losses is (B,T) of the individual losses at each position\n    // input: probs are (B,T,V) of the probabilities\n    // input: targets is (B,T) of integers giving the correct index in logits\n    for (int64_t bt = 0; bt < B * T; bt++) {\n        // loss = -log(probs[target])\n        const float* probs_bt = probs + bt * V;\n        int ix = targets[bt];\n        losses[bt] = -logf(probs_bt[ix]);\n    }\n}\n\nvoid crossentropy_softmax_backward_cpu(float* dlogits,\n                                       const float* dlosses, const float* probs, const int* targets,\n                                       int B, int T, int V) {\n    // backwards through both softmax and crossentropy\n    for (int64_t bt = 0; bt < B * T; bt++) {\n        float* dlogits_bt = dlogits + bt * V;\n        const float* probs_bt = probs + bt * V;\n        float dloss = dlosses[bt];\n        int ix = targets[bt];\n        for (int i = 0; i < V; i++) {\n            float p = probs_bt[i];\n            float indicator = i == ix ? 
1.0f : 0.0f;\n            dlogits_bt[i] = (p - indicator) * dloss;\n        }\n    }\n}\n\n// ----------------------------------------------------\n// Kernel Utils\n\n// warp-level reduction for finding the maximum value\n__device__ float warpReduceMax(float val) {\n    for (int offset = 16; offset > 0; offset /= 2) {\n        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));\n    }\n    return val;\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\nstruct SoftmaxParams {\n    float Scale;\n    float Offset;\n};\nnamespace cg = cooperative_groups;\n__device__ SoftmaxParams prepare_softmax(cg::thread_block_tile<32>& warp,\n                                         int64_t idx, const float* inp, int V, int P) {\n    // this warp (of 32) threads processes one row of inp, i.e. inp[idx, :] of shape (V,)\n    // note that inp is actually (B * T, P) but we only use the first V elements\n    // this function then calculates:\n    // 1) the max value to subtract for numerical stability and\n    // 2) the sum normalization factor\n    const float* x = inp + idx * P;\n    // thread coarsening loop, where the 32 threads serially process all V elements\n    // thread_rank() is in [0, 31], warp.size() is 32\n    float maxval = -INFINITY;\n    float sumval = 0.0f;\n    for (int i = warp.thread_rank(); i < V; i += warp.size()) {\n        float v = x[i];\n        float old_maxval = maxval;\n        // online softmax recurrence from \"Online normalizer calculation for softmax\" paper\n        maxval = fmaxf(maxval, v);\n        sumval *= expf((old_maxval - maxval));\n        sumval += expf(v - maxval);\n    }\n    // warp-level reduction to get the maxval across the 32 threads\n    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});\n    // all 32 threads do a final shift of the sum considering the global max in this row\n    sumval *= expf((maxval - global_maxval));\n    // warp-level reduction to get 
the sumval across the 32 threads\n    float global_sumval = cg::reduce(warp, sumval, cg::plus<float>{});\n    // the final normalization factor\n    float norm = 1.0f / global_sumval;\n    return SoftmaxParams{norm, global_maxval};\n}\n\n__global__ void fused_classifier_kernel1(float* dlogits, float* losses,\n                             const float* logits, const float* dlosses, const int* targets,\n                             int B, int T, int V, int P) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    // example: B = 4, T = 1024, block_size = 128 => we'd have grid_size = 1024\n    // each block of 4 warps is in charge of 4 rows of the input, one warp per row\n    // meta_group_size is the number of warps per block (e.g. 4)\n    // meta_group_rank is the index of the warp in the block (e.g. 0, 1, 2, 3)\n    int64_t idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if (idx >= B * T) { // there are B * T rows in the input\n        return;\n    }\n    int b = idx / T;\n    int t = idx % T;\n\n    // calculate the offset (maxval) and scale (sumval) for the softmax\n    SoftmaxParams sp = prepare_softmax(warp, idx, logits, V, P);\n\n    // in each row (handled by one warp), thread 0 calculates the loss\n    // calculate the probability needed for the loss and update losses\n    if(warp.thread_rank() == 0) {\n        int ix = targets[b * T + t];\n        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[b * T + t] = -logf(prob);\n    }\n\n    // finally all threads calculate the gradients\n    // prob is only materialized here temporarily and in registers, never\n    // as a full tensor that gets written to global memory\n    for (int i = warp.thread_rank(); i < V; i += warp.size()) {\n        float prob = expf(logits[idx * P + i] - sp.Offset) * sp.Scale;\n        float* dlogits_bt = dlogits + b 
* T * P + t * P;\n        float dloss = dlosses[b * T + t];\n        int ix = targets[b * T + t];\n        float indicator = i == ix ? 1.0f : 0.0f;\n        dlogits_bt[i] = (prob - indicator) * dloss;\n    }\n}\n\n\n__device__ float vec_at(const float4& vec, int index) {\n    return reinterpret_cast<const float*>(&vec)[index];\n}\n\n__device__ SoftmaxParams prepare_softmax_blockwide(cg::thread_block_tile<32>& warp,\n                                                   int64_t idx, const float* inp, int V, int P) {\n    // one row of inp, i.e. inp[idx, :] of shape (V,)\n    // float4 to get 128-bit loads and memory level parallelism\n    const float4* x_vec4 = reinterpret_cast<const float4*>(inp + idx * P);\n\n    float thread_maxval = -INFINITY;\n    float thread_sumval = 0.0f;\n    // do the loop in reverse to maximise probability of L2 cache hits\n    // so even small L2s get some hits on the 2nd read of the same thread\n    for (int i = ceil_div(V, 4) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {\n        float4 v4 = x_vec4[i];\n        #pragma unroll\n        for(int k = 0; k < 4; k++) {\n            if (i*4+k >= V) {  // bounds checking against real V\n                continue;\n            }\n            float old_maxval = thread_maxval;\n            thread_maxval = fmaxf(thread_maxval, vec_at(v4, k));\n            thread_sumval *= expf(old_maxval - thread_maxval);\n            thread_sumval += expf(vec_at(v4, k) - thread_maxval);\n        }\n    }\n\n    // two reductions of up to 1024 threads:\n    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)\n    // this results in much cleaner assembly than a multi-warp cg::reduce\n    __shared__ float shared_maxval[32];\n    __shared__ float shared_sumval[32];\n    int num_warps = blockDim.x / 32;\n    int warp_id = threadIdx.x / 32;\n    int lane_id = threadIdx.x % 32;\n\n    // reduce maxval within each warp\n    float warp_maxval = cg::reduce(warp, thread_maxval, 
cg::greater<float>{});\n    // thread 0 in each warp writes to shared memory\n    if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }\n    __syncthreads();\n    // each thread now loads the maxval across previous warps\n    // if the thread is \"out of range\" of data, use -FLT_MAX as the maxval\n    warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX;\n    // now reduce the maxval among the warp threads\n    float block_maxval = cg::reduce(warp, warp_maxval, cg::greater<float>{});\n    // each thread uses maxval to scale sumval to avoid numerical instability / overflow\n    thread_sumval *= expf(thread_maxval - block_maxval);\n    // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory\n    float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus<float>{});\n    if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }\n    __syncthreads();\n    // same strategy, now reduce sumval across warps\n    warp_sumval = (lane_id < num_warps) ? 
shared_sumval[lane_id] : 0.0f;\n    float block_sumval = cg::reduce(warp, warp_sumval, cg::plus<float>{});\n    // return the softmax parameters\n    return SoftmaxParams{1.f / block_sumval, block_maxval};\n}\n\n// Fused forward and backward pass for classifier including softmax, and logit gradients\n// Writes to both probs (only for debugging) and dlogits (only for training) are optional\n// N.B.: We may want to reuse the logits memory for dlogits, so they should *not* be __restrict__!\n__global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* probs,\n                                         const float* logits, const float* dlosses, const int* targets,\n                                         int B, int T, int V, int P) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int64_t idx = blockIdx.x;\n    int ix = targets[idx];\n\n    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)\n    SoftmaxParams sp = prepare_softmax_blockwide(warp, idx, logits, V, P);\n\n    // calculate the probability needed for the loss and update (single-threaded)\n    if(threadIdx.x == 0) {\n        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[idx] = -logf(prob);\n    }\n\n    // very sensible default for dlosses is 1/(B*T), which is the uniform loss\n    float dloss = dlosses != NULL ? 
dlosses[idx] : 1.0f / (B*T);\n    // calculate the gradients directly, saves bandwidth from probs during training\n    // but also supports writing probs for inference-only and debugging\n    const float4* logits_vec4 = reinterpret_cast<const float4*>(logits + idx * P);\n    for (int i = threadIdx.x; i < ceil_div(V, 4); i += blockDim.x) {\n        // this is the 2nd read of logits after the one in prepare_softmax2\n        // this data will never be needed again, so we reduce cache persistence\n        float4 v4 = __ldcs(&logits_vec4[i]);\n\n        #pragma unroll\n        for(int k = 0; k < 4; ++k) {\n            int element = i*4 + k;\n            float prob = expf(vec_at(v4, k) - sp.Offset) * sp.Scale;\n            prob = (element < V) ? prob : 0.0f; // bounds checking against real V\n\n            // this kernel is DRAM limited so cost of inner branch is ~zero\n            if (probs != NULL) {\n                probs[idx * P + element] = prob;\n            }\n            if (dlogits != NULL) {\n                float indicator = element == ix ? 1.0f : 0.0f;\n                dlogits[idx * P + element] = (prob - indicator) * dloss;\n            }\n        }\n    }\n}\n\n__device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp,\n                                                            int64_t idx, const float* inp, int V, int P) {\n    // same but not float4\n    // one row of inp, i.e. 
inp[idx, :] of shape (V,)\n\n    const float* x = inp + idx * P;\n    float thread_maxval = -INFINITY;\n    float thread_sumval = 0.0f;\n    // do the loop in reverse to maximise probability of L2 cache hits\n    // so even small L2s get some hits on the 2nd read of the same thread\n    for (int i = V + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {\n        float v = x[i];\n        float old_maxval = thread_maxval;\n        thread_maxval = fmaxf(thread_maxval, v);\n        thread_sumval *= expf(old_maxval - thread_maxval);\n        thread_sumval += expf(v - thread_maxval);\n    }\n\n    // two reductions of up to 1024 threads:\n    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)\n    // this results in much cleaner assembly than a multi-warp cg::reduce\n    __shared__ float shared_maxval[32];\n    __shared__ float shared_sumval[32];\n    int num_warps = blockDim.x / 32;\n    int warp_id = threadIdx.x / 32;\n    int lane_id = threadIdx.x % 32;\n\n    // reduce maxval within each warp\n    float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater<float>{});\n    // thread 0 in each warp writes to shared memory\n    if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }\n    __syncthreads();\n    // each thread now loads the maxval across previous warps\n    // if the thread is \"out of range\" of data, use -FLT_MAX as the maxval\n    warp_maxval = (lane_id < num_warps) ? 
shared_maxval[lane_id] : -FLT_MAX;\n    // now reduce the maxval among the warp threads\n    float block_maxval = cg::reduce(warp, warp_maxval, cg::greater<float>{});\n    // each thread uses maxval to scale sumval to avoid numerical instability / overflow\n    thread_sumval *= expf(thread_maxval - block_maxval);\n    // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory\n    float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus<float>{});\n    if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }\n    __syncthreads();\n    // same strategy, now reduce sumval across warps\n    warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f;\n    float block_sumval = cg::reduce(warp, warp_sumval, cg::plus<float>{});\n    // return the softmax parameters\n    return SoftmaxParams{1.f / block_sumval, block_maxval};\n}\n\n// same as 2 but not using float4\n__global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* probs,\n                                         const float* logits, const float* dlosses, const int* targets,\n                                         int B, int T, int V, int P) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int64_t idx = blockIdx.x;\n    int ix = targets[idx];\n\n    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)\n    SoftmaxParams sp = prepare_softmax_blockwide_nofloat4(warp, idx, logits, V, P);\n\n    // calculate the probability needed for the loss and update (single-threaded)\n    if(threadIdx.x == 0) {\n        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[idx] = -logf(prob);\n    }\n\n    // very sensible default for dlosses is 1/(B*T), which is the uniform loss\n    float dloss = dlosses != NULL ? 
dlosses[idx] : 1.0f / (B*T);\n    // calculate the gradients directly, saves bandwidth from probs during training\n    // but also supports writing probs for inference-only and debugging\n    const float* logits_vec = logits + idx * P;\n    for (int i = threadIdx.x; i < V; i += blockDim.x) {\n        // this is the 2nd read of logits after the one in prepare_softmax2\n        // this data will never be needed again, so we reduce cache persistence\n        float v = __ldcs(&logits_vec[i]);\n        float prob = expf(v - sp.Offset) * sp.Scale;\n        if (probs != NULL) {\n            probs[idx * P + i] = prob;\n        }\n        if (dlogits != NULL) {\n            float indicator = (i == ix) ? 1.0f : 0.0f;\n            dlogits[idx * P + i] = (prob - indicator) * dloss;\n        }\n    }\n}\n\n__device__ SoftmaxParams prepare_softmax_blockwide2(int64_t idx, const floatX* inp, int V, int P) {\n    // one row of inp, i.e. inp[idx, :] of shape (V,)\n\n    const floatX* x = inp + idx * P;\n    float thread_maxval = -INFINITY;\n    float thread_sumval = 0.0f;\n    // do the loop in reverse to maximise probability of L2 cache hits\n    // so even small L2s get some hits on the 2nd read of the same thread\n    for (int i = ceil_div(V, x128::size) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {\n        x128 packed_x = load128cs(x + i * x128::size); // load and do not keep in cache\n        for(int k = 0; k < packed_x.size; ++k) {\n            if (i*x128::size+k >= V) {  // bounds checking against real V\n                continue;\n            }\n            float v = (float)packed_x[k];\n            float old_maxval = thread_maxval;\n            thread_maxval = fmaxf(thread_maxval, v);\n            thread_sumval *= expf(old_maxval - thread_maxval);\n            thread_sumval += expf(v - thread_maxval);\n        }\n    }\n    // two reductions of up to 1024 threads:\n    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)\n    // 
this results in much cleaner assembly than a multi-warp cg::reduce\n    __shared__ float shared_maxval[32];\n    __shared__ float shared_sumval[32];\n    int num_warps = blockDim.x / 32;\n    int warp_id = threadIdx.x / 32;\n    int lane_id = threadIdx.x % 32;\n\n    // reduce maxval within each warp\n    float warp_maxval = warpReduceMax(thread_maxval);\n    // thread 0 in each warp writes to shared memory\n    if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }\n    __syncthreads();\n    // each thread now loads the maxval across previous warps\n    // if the thread is \"out of range\" of data, use -FLT_MAX as the maxval\n    warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX;\n    // now reduce the maxval among the warp threads\n    float block_maxval = warpReduceMax(warp_maxval);\n    // each thread uses maxval to scale sumval to avoid numerical instability / overflow\n    thread_sumval *= expf(thread_maxval - block_maxval);\n    // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory\n    float warp_sumval = warpReduceSum(thread_sumval); //cg::reduce(warp, thread_sumval, cg::plus<float>{});\n\n    if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }\n    __syncthreads();\n    // same strategy, now reduce sumval across warps\n    warp_sumval = (lane_id < num_warps) ? 
shared_sumval[lane_id] : 0.0f;\n    float block_sumval = warpReduceSum(warp_sumval); //cg::reduce(warp, thread_sumval, cg::plus<float>{});\n    // return the softmax parameters\n    return SoftmaxParams{1.f / block_sumval, block_maxval};\n}\n\n// same as 2 but using x128\n__global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs,\n                                         const floatX* logits, const floatX* dlosses, const int* targets,\n                                         int B, int T, int V, int P) {\n    int64_t idx = blockIdx.x;\n    int ix = targets[idx];\n\n    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)\n    SoftmaxParams sp = prepare_softmax_blockwide2(idx, logits, V, P);\n\n    // calculate the probability needed for the loss and update (single-threaded)\n    if(threadIdx.x == 0) {\n        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[idx] = -logf(prob);\n    }\n\n    // very sensible default for dlosses is 1/(B*T), which is the uniform loss\n    float dloss = dlosses != NULL ? 
(float)dlosses[idx] : 1.0f / (B*T);\n    // calculate the gradients directly, saves bandwidth from probs during training\n    // but also supports writing probs for inference-only and debugging\n    const floatX* logits_vec = logits + idx * P;\n    for (int i = threadIdx.x; i < ceil_div(V , x128::size); i += blockDim.x) {\n        // this is the 2nd read of logits after the one in prepare_softmax2\n        // this data will never be needed again, so we reduce cache persistence\n        x128 packed_logits_vec = load128cs(logits_vec + i * x128::size); // load and do not keep in cache\n        x128 packed_probs;\n        x128 packed_dlogits;\n        for(int k = 0; k < packed_logits_vec.size; ++k) {\n            int element = i*packed_logits_vec.size + k;\n            if (element >= V) {  // bounds checking against real V\n                continue;\n            }\n            float v = packed_logits_vec[k];\n            float prob = expf(v - sp.Offset) * sp.Scale;\n            packed_probs[k] = prob;\n            float indicator = (element == ix) ? 1.0f : 0.0f;\n            packed_dlogits[k] = (prob - indicator) * dloss;\n        }\n        // Note: missing .cs hint hurts our performance due to cache thrashing, fixed in kernel5\n        store128(dlogits + idx * P + i * packed_logits_vec.size, packed_dlogits);\n        if (probs != NULL) {\n            store128(probs + idx * P + i * packed_logits_vec.size, packed_probs);\n        }\n    }\n}\n\n__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) {\n    // same but not float4\n    // one row of inp, i.e. 
inp[idx, :] of shape (V,)\n\n    const floatX* x = inp + idx * P;\n    float thread_maxval = -INFINITY;\n    float thread_sumval = 0.0f;\n    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;\n\n    // special-case loop to handle the unaligned elements at the end of the array\n    // this lets us skip the bounds check in the main loop below, which improves performance\n    while ((i+1)*x128::size > V) {\n        for(int k = 0; k < x128::size; ++k) {\n            if (i*x128::size+k >= V) {\n                break; // bounds checking against real V (rather than padded P)\n            }\n            float v = (float)x[i*x128::size+k];\n            float old_maxval = thread_maxval;\n            thread_maxval = fmaxf(thread_maxval, v);\n            thread_sumval *= expf((old_maxval - thread_maxval));\n            thread_sumval += expf(v - thread_maxval);\n        }\n        i -= blockDim.x;\n    }\n\n    // main loop for the bulk of the iterations (no bounds checking required!)\n    for (; i >= 0; i -= blockDim.x) {\n        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop\n        for(int k = 0; k < x128::size; ++k) {\n            float v = (float)packed_x[k];\n            float old_maxval = thread_maxval;\n            thread_maxval = fmaxf(thread_maxval, v);\n            thread_sumval *= expf((old_maxval - thread_maxval));\n            thread_sumval += expf(v - thread_maxval);\n        }\n    }\n\n    // Block Max Reduction -> Maths -> Block Sum Reduction\n    float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -FLT_MAX);\n    thread_sumval *= expf(thread_maxval - block_maxval);\n    float block_sumval = blockReduce<warpReduceSum>(thread_sumval);\n\n    // return the softmax parameters\n    return SoftmaxParams{1.f / block_sumval, block_maxval};\n}\n\n// will _update_ logits to logit gradients\n// uses template to decide whether to write logits and probs\n// split both loops in 
\"multiple-of-x128-size\" and \"bounds-checked remainder\" parts\ntemplate <bool WriteLogits = true, bool WriteProbs = false>\n__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)\n                fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs,\n                                         const floatX* logits, const floatX* dlosses, const int* targets,\n                                         int B, int T, int V, int P) {\n    int64_t idx = blockIdx.x;\n    int ix = targets[idx];\n\n    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)\n    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);\n\n    // calculate the probability needed for the loss and update (single-threaded)\n    if(threadIdx.x == 0) {\n        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[idx] = (floatX)(-logf(prob));\n    }\n\n    // very sensible default for dlosses is 1/(B*T), which is the uniform loss\n    float dloss = (dlosses != NULL) ? (float)dlosses[idx] : 1.0f / (B*T);\n    // calculate the gradients directly, saves bandwidth from probs during training\n    // but also supports writing probs for inference-only and debugging\n    const floatX* logits_vec = logits + idx * P;\n    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {\n        // this is the 2nd read of logits after the one in prepare_softmax2\n        // it will be overwritten by the logits gradients which is when we reduce cache persistence\n        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs\n        x128 packed_probs;\n        for(int k = 0; k < x128::size; ++k) {\n            int element = i*x128::size + k;\n            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;\n            packed_probs[k] = (floatX)prob;\n            float indicator = (element == ix) ? 
1.0f : 0.0f;\n            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);\n        }\n        if (WriteLogits){\n            // reduce cache persistence for the overwritten logits\n            // to maximise probability that logits remain in cache between prepare_softmax and here\n            store128cs(dlogits + idx * P + i * x128::size, packed_logits_vec);\n        }\n        if (WriteProbs) {\n            store128(probs + idx * P + i * x128::size, packed_probs);\n        }\n    }\n\n    // handle remaining elements after the last multiple of x128::size\n    // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements\n    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size\n    for (int i = threadIdx.x + unaligned_start; i < V; i++) {\n        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;\n        float indicator = (i == ix) ? 1.0f : 0.0f;\n        float dlogit = (prob - indicator) * dloss;\n        if (WriteLogits){\n            __stcs(dlogits + idx * P + i, (floatX)dlogit);\n        }\n        if (WriteProbs) {\n            probs[idx * P + i] = (floatX)prob;\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid fused_classifier1(float* dlogits, float* losses,\n                      const float* logits, const float* dlosses, const int* targets,\n                      int B, int T, int V, int P, int block_size) {\n    const int N = B * T; // total number of rows in the input\n    // how many rows of the input can each block of threads process?\n    // e.g. 
in block_size=128, 4 rows get handled by 4 warps (of 32 threads each)\n    const int rows_per_block = block_size / 32;\n    const int grid_size = N / rows_per_block; // total number of blocks needed\n    fused_classifier_kernel1<<<grid_size, block_size>>>(dlogits, losses, logits, dlosses, targets, B, T, V, P);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_classifier2(float* dlogits, float* losses,\n                      const float* logits, const float* dlosses, const int* targets,\n                      int B, int T, int V, int P, int block_size) {\n    const int N = B * T;\n    const int grid_size = N;\n    fused_classifier_kernel2<<<grid_size, block_size>>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_classifier3(float* dlogits, float* losses,\n                      const float* logits, const float* dlosses, const int* targets,\n                      int B, int T, int V, int P, int block_size) {\n    const int N = B * T;\n    const int grid_size = N;\n    fused_classifier_kernel3<<<grid_size, block_size>>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_classifier4(float* dlogits, float* losses,\n                      const float* logits, const float* dlosses, const int* targets,\n                      int B, int T, int V, int P, int block_size) {\n    const int N = B * T;\n    const int grid_size = N;\n    fused_classifier_kernel4<<<grid_size, block_size>>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_classifier5(float* dlogits, float* losses,\n                      const float* logits, const float* dlosses, const int* targets,\n                      int B, int T, int V, int P, int block_size) {\n    const int N = B * T;\n    const int grid_size = N;\n    fused_classifier_kernel5<true,false><<<grid_size, 
block_size>>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_classifier(int kernel_num, float* dlogits, float* losses,\n                      const float* logits, const float* dlosses, const int* targets,\n                      int B, int T, int V, int P, int block_size) {\n    switch (kernel_num) {\n        case 1:\n            fused_classifier1(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);\n            break;\n        case 2:\n            fused_classifier2(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);\n            break;\n        case 3:\n            fused_classifier3(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);\n            break;\n        case 4:\n            fused_classifier4(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);\n            break;\n        case 5:\n            fused_classifier5(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int64_t B = 8;              // batch size\n    int64_t T = 1024;           // sequence length\n    int64_t V = 50257;          // vocab size\n    int64_t P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* logits = make_random_float(B * T * V);\n    float* probs = make_random_float_01(B * T * V);\n    float* dlogits = (float*)malloc(B * T * V * sizeof(float));\n    float* losses = (float*)malloc(B * T * sizeof(float));\n    float* dlosses = make_random_float(B * T);\n    int* targets = make_random_int(B * T, V);\n    // 
make the input less uniformly random: Otherwise, all probabilities will be basically zero,\n    // and the tests are not actually meaningful.\n    int* outliers = make_random_int(B * T * 3, V);\n    for(int k = 0; k < 3; ++k) {\n        for(int j = 0; j < B * T; ++j) {\n            logits[j * V +  outliers[j*3 + k]] *= 20;\n        }\n    }\n\n    // move to GPU\n    int *d_targets;\n    float *d_logits, *d_losses;\n    float *d_dlogits, *d_dlosses, *d_dlogits_no_pad;\n    cudaCheck(cudaMalloc(&d_dlogits, B * T * P * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_logits, B * T * P * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dlogits_no_pad, B * T * V * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));\n    cudaCheck(cudaMalloc(&d_losses, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dlosses, B * T * sizeof(float)));\n\n    // move to GPU\n    cudaCheck(cudaMemset(d_logits, 0xff, B * T * P * sizeof(float)));\n    cudaCheck(cudaMemcpy2D(d_logits, P * sizeof(float), logits, V * sizeof(float), V * sizeof(float), B * T, cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_dlosses, dlosses, B * T * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // define block sizes we'll use in correctness and timing\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    // first check the correctness of the kernel\n    softmax_forward_cpu(probs, logits, B * T, V);\n    crossentropy_forward_cpu(losses, probs, targets, B, T, V);\n    crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V);\n\n#if defined(ENABLE_BF16) || defined(ENABLE_FP16)\n    if (kernel_num < 4) // kernel 4/5 + BF16 is only for testing performance, it doesn't do the format 
conversions yet etc...\n#endif\n    {\n        // validate the kernel against the CPU reference at each block size\n        for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n            int block_size = block_sizes[j];\n            printf(\"Checking block size %d.\\n\", block_size);\n            fused_classifier(kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets, B, T, V, P, block_size);\n            validate_result(d_losses, losses, \"losses\", B * T, 1e-4f);\n            // undo the padding before we can check for correctness\n            cudaCheck(cudaMemcpy2D(d_dlogits_no_pad, V * sizeof(float), d_dlogits, P * sizeof(float), V * sizeof(float), B * T, cudaMemcpyDeviceToDevice));\n            validate_result(d_dlogits_no_pad, dlogits, \"dlogits\", B * T * V, 1e-4f);\n        }\n        printf(\"All results match. Starting benchmarks.\\n\\n\");\n    }\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        int repeat_times = 1000;\n        float elapsed_time = benchmark_kernel(repeat_times, fused_classifier,\n                                              kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets,\n                                              B, T, V, P, block_size);\n        printf(\"block_size %4d | time %f ms\\n\", block_size, elapsed_time);\n    }\n\n    // free memory\n    free(logits);\n    free(probs);\n    free(dlogits);\n    free(losses);\n    free(dlosses);\n    free(targets);\n    free(outliers);\n    cudaCheck(cudaFree(d_dlogits));\n    cudaCheck(cudaFree(d_losses));\n    cudaCheck(cudaFree(d_logits));\n    cudaCheck(cudaFree(d_dlosses));\n    cudaCheck(cudaFree(d_targets));\n    cudaCheck(cudaFree(d_dlogits_no_pad));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/common.h",
    "content": "#include <stdlib.h>\n#include <stdio.h>\n#include <cuda_runtime.h>\n#include <cublas_v2.h>\n#include <cublasLt.h>\n#include <float.h>\n\n#define WARP_SIZE 32U\nextern cudaDeviceProp deviceProp;\n\ntemplate<class T>\n__host__ __device__ T ceil_div(T dividend, T divisor) {\n    return (dividend + divisor-1) / divisor;\n}\n\n__device__ float warpReduceSum(float val) {\n    for (int offset = 16; offset > 0; offset /= 2) {\n        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);\n    }\n    return val;\n}\n\n// requires all 32 threads in the warp to be active, but should work for any block size\n// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes\n// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end\n// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1\nusing reduction_func_t = float (*) (float);\n\ntemplate<reduction_func_t warp_reduction>\n__device__ inline float blockReduce(float val, bool final_sync, float out_of_bounds) {\n    // two reductions of up to 1024 threads:\n    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)\n    __shared__ float shared_val[WARP_SIZE];\n    const int lane_id = threadIdx.x % WARP_SIZE;\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int num_warps = blockDim.x / WARP_SIZE;\n\n    float warp_val = warp_reduction(val);\n    if (lane_id == 0) { shared_val[warp_id] = warp_val; }\n    __syncthreads();\n    warp_val = (lane_id < num_warps) ? 
shared_val[lane_id] : out_of_bounds;\n    float block_val = warp_reduction(warp_val);\n\n    if (final_sync) {\n        __syncthreads(); // only needed in loops when effectively reusing shared memory etc.\n    }\n    return block_val;\n}\n\n// Helper function to call blockReduce with default arguments\ntemplate<reduction_func_t warp_reduction>\n__device__ inline float blockReduce(float val) {\n    return blockReduce<warp_reduction>(val, false, 0.0f);\n}\n\n// ----------------------------------------------------------------------------\n// checking utils\n\n// CUDA error checking\nvoid cuda_check(cudaError_t error, const char *file, int line) {\n    if (error != cudaSuccess) {\n        printf(\"[CUDA ERROR] at file %s:%d:\\n%s\\n\", file, line,\n               cudaGetErrorString(error));\n        exit(EXIT_FAILURE);\n    }\n};\n#define cudaCheck(err) (cuda_check(err, __FILE__, __LINE__))\n\n// cuBLAS error checking\nvoid cublasCheck(cublasStatus_t status, const char *file, int line)\n{\n    if (status != CUBLAS_STATUS_SUCCESS) {\n        printf(\"[cuBLAS ERROR]: %d %s %d\\n\", status, file, line);\n        exit(EXIT_FAILURE);\n    }\n}\n#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }\n\n// ----------------------------------------------------------------------------\n// cuBLAS setup\n// these will be initialized by setup_main\n\n// cuBLAS workspace. 
Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK\nstatic size_t cublaslt_workspace_size = 32 * 1024 * 1024;\nstatic void* cublaslt_workspace = NULL;\nstatic cublasComputeType_t cublas_compute_type;\ncublasHandle_t cublas_handle;\ncublasLtHandle_t cublaslt_handle;\nint cuda_arch_major = 0;\nint cuda_arch_minor = 0;\nint cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM\nint cuda_threads_per_SM = 0;    // needed to calculate how many blocks to launch to fill up the GPU\n\n// ----------------------------------------------------------------------------\n// to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance\n#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900\n#define MAX_1024_THREADS_BLOCKS 2\n#else\n#define MAX_1024_THREADS_BLOCKS 1\n#endif\n\n// ----------------------------------------------------------------------------\n// Packed128 data structure, which forces the compiler to use 128-bit loads/stores\n// in GPUs that support (the LDG.128 and STS.128 instructions)\n// This is a bit similar to the use of float4 in the case of 32-bit floats, but\n// supports arbitrary precision.\n\ntemplate<class ElementType>\nstruct alignas(16) Packed128 {\n    // Note: = default implicitly generates a __device__ function, but explicitly\n    // adding __device__ causes a lot of warnings.\n    Packed128() = default;\n    __device__ explicit Packed128(int4 bits) {\n        static_assert(sizeof(bits) == sizeof(payload), \"Size mismatch.\");\n        memcpy(&payload, &bits, sizeof(bits));\n    }\n\n    __device__  static Packed128 constant(ElementType value) {\n        Packed128 result;\n        for(int k = 0; k < size; ++k) {\n            result.payload[k] = value;\n        }\n        return result;\n    }\n\n    __device__ static Packed128 zeros() {\n        return constant(0);\n    }\n\n    __device__ static Packed128 ones() {\n        return constant(1);\n    }\n\n    __device__ ElementType& operator[](int index) {\n  
      return payload[index];\n    }\n    __device__ const ElementType& operator[](int index) const {\n        return payload[index];\n    }\n    __device__ int4 get_bits() const {\n        int4 bits;\n        static_assert(sizeof(bits) == sizeof(payload), \"Size mismatch.\");\n        memcpy(&bits, &payload, sizeof(bits));\n        return bits;\n    }\n    // e.g. sizeof(int4) is 16 (4 X 4 bytes), sizeof(bfloat16) = 2, so size = 8\n    // so in the case where ElementType = bfloat16, we store 8 elements in one Packed128\n    static constexpr const int size = sizeof(int4) / sizeof(ElementType);\n    ElementType payload[size];\n};\n\n// short-form typedef\ntypedef Packed128<float> f128;\n\n// load a Packed128 from an aligned memory address\ntemplate<class ElementType>\n__device__ Packed128<ElementType> load128(const ElementType* address) {\n    return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};\n}\n// load a Packed128 from an aligned memory address with streaming cache hint\ntemplate<class ElementType>\n__device__ Packed128<ElementType> load128cs(const ElementType* address) {\n    return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};\n}\n// store a Packed128 to an aligned memory address\ntemplate<class ElementType>\n__device__ void store128(ElementType* target, Packed128<ElementType> value) {\n    *reinterpret_cast<int4*>(target) = value.get_bits();\n}\n// store a Packed128 to an aligned memory address with streaming cache hint\ntemplate<class ElementType>\n__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {\n    __stcs(reinterpret_cast<int4*>(target), value.get_bits());\n}\n// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1\ntemplate<class ElementType>\n__device__ void store128cg(ElementType* target, Packed128<ElementType> value) {\n    __stcg(reinterpret_cast<int4*>(target), value.get_bits());\n}\n\n// 
----------------------------------------------------------------------------\n// reduced/mixed precision utilities\n\n#if defined(ENABLE_BF16)\n\ntypedef __nv_bfloat16 floatX;\ntypedef __nv_bfloat16 floatN;\n#define CUBLAS_LOWP CUDA_R_16BF // CUDA_R_16F or CUDA_R_16BF (or CUDA_R_32F)\n// CUBLAS_COMPUTE_32F or CUBLAS_COMPUTE_16F (for CUDA_R_16F only, potentially slower?!)\n#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F\n\n#elif defined(ENABLE_FP16)\n\ntypedef half floatX;\ntypedef half floatN;\n\n#else\n\ntypedef float floatX;\ntypedef float floatN;\n#endif\n\ntypedef Packed128<floatX> x128;\n\n\n// older nvcc does not provide __ldcs and __stcs for bfloat16, despite these actually just being unsigned shorts.\n// we need to be careful here to only define our own versions if none already exist, otherwise the compiler will\n// complain.\n// If not, you easily get \"no viable overload\" (for sm52) and \"function already exists\" (sm_80)\n#if defined(ENABLE_BF16) && (__CUDACC_VER_MAJOR__ < 12) && !((__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__))\n__device__ floatX __ldcs(const floatX* address) {\n    unsigned short bf = __ldcs(reinterpret_cast<const unsigned short*>(address));\n    return __nv_bfloat16_raw{bf};\n}\n\n__device__ void __stcs(floatX* address, floatX value) {\n    __stcs(reinterpret_cast<unsigned short*>(address), ((__nv_bfloat16_raw)value).x);\n}\n#endif\n\n\n// ----------------------------------------------------------------------------\n// random utils\n\nfloat* make_random_float_01(size_t N) {\n    float* arr = (float*)malloc(N * sizeof(float));\n    for (size_t i = 0; i < N; i++) {\n        arr[i] = ((float)rand() / RAND_MAX); // range 0..1\n    }\n    return arr;\n}\n\nfloat* make_random_float(size_t N) {\n    float* arr = (float*)malloc(N * sizeof(float));\n    for (size_t i = 0; i < N; i++) {\n        arr[i] = ((float)rand() / RAND_MAX) * 2.0 - 1.0; // range -1..1\n    }\n    return arr;\n}\n\nint* make_random_int(size_t N, int V) {\n    int* arr 
= (int*)malloc(N * sizeof(int));\n    for (size_t i = 0; i < N; i++) {\n        arr[i] = rand() % V; // range 0..V-1\n    }\n    return arr;\n}\n\nfloat* make_zeros_float(size_t N) {\n    float* arr = (float*)malloc(N * sizeof(float));\n    memset(arr, 0, N * sizeof(float)); // all zero\n    return arr;\n}\n\nfloat* make_ones_float(size_t N) {\n    float* arr = (float*)malloc(N * sizeof(float));\n    for (size_t i = 0; i < N; i++) {\n        arr[i] = 1.0f;\n    }\n    return arr;\n}\n\n// ----------------------------------------------------------------------------\n// testing and benchmarking utils\n\ntemplate<class TargetType>\n[[nodiscard]] cudaError_t memcpy_convert(TargetType* d_ptr, float* h_ptr, size_t count) {\n    // copy from host to device with data type conversion.\n    TargetType* converted = (TargetType*)malloc(count * sizeof(TargetType));\n    for (int i = 0; i < count; i++) {\n        converted[i] = (TargetType)h_ptr[i];\n    }\n\n    cudaError_t status = cudaMemcpy(d_ptr, converted, count * sizeof(TargetType), cudaMemcpyHostToDevice);\n    free(converted);\n\n    // instead of checking the status at cudaMemcpy, we return it from here. 
This way, we\n    // still need to use our checking macro, and get better line info as to where the error\n    // happened.\n    return status;\n}\n\nvoid setup_main() {\n    srand(0);   // determinism\n\n    // set up the device\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n    cuda_num_SMs = deviceProp.multiProcessorCount;\n    cuda_threads_per_SM = deviceProp.maxThreadsPerMultiProcessor;\n    cuda_arch_major = deviceProp.major;\n    cuda_arch_minor = deviceProp.minor;\n\n    // setup cuBLAS and cuBLASLt\n    cublasCheck(cublasCreate(&cublas_handle));\n    cublasCheck(cublasLtCreate(&cublaslt_handle));\n    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));\n\n    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')\n    int enable_tf32 = cuda_arch_major >= 8 ? 1 : 0;\n    // TODO implement common CLI for all tests/benchmarks\n    // if (override_enable_tf32 == 0) { enable_tf32 = 0; } // force to zero via arg\n    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;\n    cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;\n    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));\n}\n\ntemplate<class D, class T>\nvoid validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) {\n    D* out_gpu = (D*)malloc(num_elements * sizeof(D));\n    cudaCheck(cudaMemcpy(out_gpu, device_result, num_elements * sizeof(D), cudaMemcpyDeviceToHost));\n    int nfaults = 0;\n#ifndef ENABLE_BF16\n    float epsilon = FLT_EPSILON;\n#else\n    float epsilon = 0.079;\n#endif\n    for (int i = 0; i < num_elements; i++) {\n        // Skip masked elements\n        if(!isfinite(cpu_reference[i]))\n            continue;\n\n        // print the first few comparisons\n        if (i < 5) {\n            printf(\"%f %f\\n\", cpu_reference[i], (T)out_gpu[i]);\n        }\n        // effective tolerance is based on expected rounding error (epsilon),\n        // plus any specified additional tolerance\n        float t_eff = tolerance + fabs(cpu_reference[i]) * epsilon;\n        // ensure correctness for all elements.\n        if (fabs(cpu_reference[i] - (T)out_gpu[i]) > t_eff) {\n            printf(\"Mismatch of %s at %d: CPU_ref: %f vs GPU: %f\\n\", name, i, cpu_reference[i], (T)out_gpu[i]);\n            nfaults ++;\n            if (nfaults >= 10) {\n                free(out_gpu);\n                exit(EXIT_FAILURE);\n            }\n        }\n    }\n\n    if (nfaults > 0) {\n        free(out_gpu);\n        exit(EXIT_FAILURE);\n    }\n\n    free(out_gpu);\n}\n\ntemplate<class Kernel, class... KernelArgs>\nfloat benchmark_kernel(int repeats, Kernel kernel, KernelArgs&&... 
kernel_args) {\n    cudaEvent_t start, stop;\n    // prepare buffer to scrub L2 cache between benchmarks\n    // just memset a large dummy array, recommended by\n    // https://stackoverflow.com/questions/31429377/how-can-i-clear-flush-the-l2-cache-and-the-tlb-of-a-gpu\n    // and apparently used in nvbench.\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaCheck(cudaGetDeviceProperties(&deviceProp, deviceIdx));\n    void* flush_buffer;\n    cudaCheck(cudaMalloc(&flush_buffer, deviceProp.l2CacheSize));\n\n    cudaCheck(cudaEventCreate(&start));\n    cudaCheck(cudaEventCreate(&stop));\n    float elapsed_time = 0.f;\n    for (int i = 0; i < repeats; i++) {\n        // clear L2\n        cudaCheck(cudaMemset(flush_buffer, 0, deviceProp.l2CacheSize));\n        // now we can start recording the timing of the kernel\n        cudaCheck(cudaEventRecord(start, nullptr));\n        kernel(std::forward<KernelArgs>(kernel_args)...);\n        cudaCheck(cudaEventRecord(stop, nullptr));\n        cudaCheck(cudaEventSynchronize(start));\n        cudaCheck(cudaEventSynchronize(stop));\n        float single_call;\n        cudaCheck(cudaEventElapsedTime(&single_call, start, stop));\n        elapsed_time += single_call;\n    }\n\n    cudaCheck(cudaFree(flush_buffer));\n\n    return elapsed_time / repeats;\n}"
  },
  {
    "path": "dev/cuda/crossentropy_forward.cu",
    "content": "/*\nKernels for crossentropy forward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt crossentropy_forward.cu -o crossentropy_forward\n\nversion 1 is a straight-forward port from CPU code to kernel, parallel over B,T\n./crossentropy_forward 1\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid crossentropy_forward_cpu(float* losses,\n                            const float* probs, const int* targets,\n                            int B, int T, int V) {\n    // output: losses is (B,T) of the individual losses at each position\n    // input: probs are (B,T,V) of the probabilities\n    // input: targets is (B,T) of integers giving the correct index in logits\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // loss = -log(probs[target])\n            const float* probs_bt = probs + b * T * V + t * V;\n            int ix = targets[b * T + t];\n            losses[b * T + t] = -logf(probs_bt[ix]);\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n__global__ void crossentropy_forward_kernel1(float* losses,\n                            const float* probs, const int* targets,\n                            int B, int T, int V) {\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < B * T) {\n        int b = i / T;\n        int t = i % T;\n        const float* probs_bt = probs + b * T * V + t * V;\n        int ix = targets[b * T + t];\n        losses[b * T + t] = -logf(probs_bt[ix]);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid crossentropy_forward1(float* losses,\n                            const float* probs, const int* targets,\n                            int B, int T, int V,\n          
                  const int block_size) {\n    const int N = B * T;\n    const int grid_size = ceil_div(N, block_size);\n    crossentropy_forward_kernel1<<<grid_size, block_size>>>(losses, probs, targets, B, T, V);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid crossentropy_forward(int kernel_num,\n                          float* losses,\n                          const float* probs, const int* targets,\n                          int B, int T, int V,\n                          const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            crossentropy_forward1(losses, probs, targets, B, T, V, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int V = 50257;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * sizeof(float));\n    float* probs = make_random_float_01(B * T * V);\n    int* targets = make_random_int(B * T, V);\n\n    // move to GPU\n    float* d_out;\n    float* d_probs;\n    int* d_targets;\n    cudaCheck(cudaMalloc(&d_out, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));\n    cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    crossentropy_forward_cpu(out, probs, targets, B, T, V);\n  
  // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        crossentropy_forward(kernel_num, d_out, d_probs, d_targets, B, T, V, block_size);\n        validate_result(d_out, out, \"out\", B * T, 1e-5f);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n        float elapsed_time = benchmark_kernel(repeat_times, crossentropy_forward,\n                                              kernel_num, d_out, d_probs, d_targets,\n                                              B, T, V, block_size);\n\n        printf(\"block_size %4d | time %.4f ms | per token %.2f ns\\n\", block_size, elapsed_time, elapsed_time * 1'000'000 / (B*T));\n    }\n\n    // free memory\n    free(out);\n    free(probs);\n    free(targets);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_probs));\n    cudaCheck(cudaFree(d_targets));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/crossentropy_softmax_backward.cu",
    "content": "/*\nKernels for crossentropy forward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt crossentropy_softmax_backward.cu -o crossentropy_softmax_backward\n\nversion 1 is a straight-forward port from CPU code to kernel, parallel over B,T\n./crossentropy_softmax_backward 1\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid crossentropy_softmax_backward_cpu(float* dlogits,\n                           const float* dlosses, const float* probs, const int* targets,\n                           int B, int T, int V) {\n    // backwards through both softmax and crossentropy\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dlogits_bt = dlogits + b * T * V + t * V;\n            const float* probs_bt = probs + b * T * V + t * V;\n            float dloss = dlosses[b * T + t];\n            int ix = targets[b * T + t];\n            for (int i = 0; i < V; i++) {\n                float p = probs_bt[i];\n                float indicator = i == ix ? 
1.0f : 0.0f;\n                dlogits_bt[i] += (p - indicator) * dloss;\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// naive kernel that just parallelizes over B,T,V\n__global__ void crossentropy_softmax_backward_kernel1(float* dlogits,\n                           const float* dlosses, const float* probs, const int* targets,\n                           int B, int T, int V) {\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < B * T * V) {\n        int b = i / (T * V);\n        int t = (i / V) % T;\n        int v = i % V;\n        float* dlogits_bt = dlogits + b * T * V + t * V;\n        const float* probs_bt = probs + b * T * V + t * V;\n        float dloss = dlosses[b * T + t];\n        int ix = targets[b * T + t];\n        float p = probs_bt[v];\n        float indicator = v == ix ? 1.0f : 0.0f;\n        dlogits_bt[v] += (p - indicator) * dloss;\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid crossentropy_softmax_backward1(float* dlogits,\n                           const float* dlosses, const float* probs, const int* targets,\n                           int B, int T, int V,\n                           const int block_size) {\n    const int N = B * T * V;\n    const int grid_size = ceil_div(N, block_size);\n    crossentropy_softmax_backward_kernel1<<<grid_size, block_size>>>(dlogits, dlosses, probs, targets, B, T, V);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid crossentropy_softmax_backward(int kernel_num,\n                           float* dlogits,\n                           const float* dlosses, const float* probs, const int* targets,\n                           int B, int T, int V,\n                           const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            crossentropy_softmax_backward1(dlogits, dlosses, probs, targets, 
B, T, V, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int V = 50257;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* probs = make_random_float_01(B * T * V);\n    int* targets = make_random_int(B * T, V);\n    float* dlosses = make_random_float(B * T);\n    float* dlogits = make_zeros_float(B * T * V);\n\n    // move to GPU\n    float* d_probs;\n    int* d_targets;\n    float* d_dlosses;\n    float* d_dlogits;\n    cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));\n    cudaCheck(cudaMalloc(&d_dlosses, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dlogits, B * T * V * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_dlosses, dlosses, B * T * sizeof(float), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V);\n\n    // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        cudaCheck(cudaMemset(d_dlogits, 0, B * T * V * sizeof(float)));\n        printf(\"Checking block size %d.\\n\", block_size);\n        
crossentropy_softmax_backward(kernel_num, d_dlogits, d_dlosses, d_probs, d_targets, B, T, V, block_size);\n        validate_result(d_dlogits, dlogits, \"dlogits\", B * T * V, 1e-5f);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 100;\n        float elapsed_time = benchmark_kernel(repeat_times, crossentropy_softmax_backward,\n                                              kernel_num, d_dlogits, d_dlosses, d_probs, d_targets,\n                                              B, T, V, block_size);\n\n        printf(\"block_size %4d | time %.4f ms | per token %.2f µs\\n\", block_size, elapsed_time, elapsed_time * 1'000 / (B*T));\n    }\n\n    // free memory\n    free(probs);\n    free(targets);\n    free(dlosses);\n    free(dlogits);\n    cudaCheck(cudaFree(d_probs));\n    cudaCheck(cudaFree(d_targets));\n    cudaCheck(cudaFree(d_dlosses));\n    cudaCheck(cudaFree(d_dlogits));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/encoder_backward.cu",
    "content": "/*\nKernels for the positional encoder forward pass in GPT-2.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt encoder_backward.cu -o encoder_backward\n\nversion 1 is naive port from CPU code to kernel\nparallelizes over B,T,C, uses atomics to add to dwte, dwpe\n./encoder_backward 1\n\nversion 2 is another naive port\nparallelizes over C, loops over B,T; much slower than version 1\n./encoder_backward 2\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\n// GPT-2 positional encoder forward pass\nvoid encoder_backward_cpu(float* dwte, float* dwpe,\n                            float* dout, int* inp,\n                            int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dout_bt = dout + b * T * C + t * C;\n            int ix = inp[b * T + t];\n            float* dwte_ix = dwte + ix * C;\n            float* dwpe_t = dwpe + t * C;\n            for (int i = 0; i < C; i++) {\n                float d = dout_bt[i];\n                dwte_ix[i] += d;\n                dwpe_t[i] += d;\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// naive implementation with atomics\n__global__ void encoder_backward_kernel1(float* dwte, float* dwpe,\n                                        const float* dout, const int* inp,\n                                        int B, int T, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T * C;\n\n    if (idx < N) {\n        int bt = idx / C;\n        int b = bt / T;\n        int t = bt % T;\n        int c = idx % C;\n\n        int ix = inp[b * T + t];\n\n        const float* dout_btc = dout + b * T * C + t * C + c;\n        float* dwte_ix = dwte + ix * C + c;\n        float* 
dwpe_tc = dwpe + t * C + c;\n\n        atomicAdd(dwte_ix, *dout_btc);\n        atomicAdd(dwpe_tc, *dout_btc);\n    }\n}\n\n// naive implementation that parallelizes over C and loops over B,T\n// but it gets rid of atomics\n__global__ void encoder_backward_kernel2(float* dwte, float* dwpe,\n                                        const float* dout, const int* inp,\n                                        int B, int T, int C) {\n    int c = blockIdx.x * blockDim.x + threadIdx.x;\n    if (c >= C) { return; } // guard\n    int BT = B * T;\n    for (int i = 0; i < BT; i++) {\n        int t = i % T;\n        int ix = inp[i];\n        float dout_btc = dout[i * C + c];\n        dwte[ix * C + c] += dout_btc;\n        dwpe[t * C + c] += dout_btc;\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid encoder_backward1(float* dwte, float* dwpe,\n                    const float* dout, const int* inp,\n                    int B, int T, int C,\n                    const int block_size) {\n    const int N = B * T * C;\n    const int grid_size = ceil_div(N, block_size);\n    encoder_backward_kernel1<<<grid_size, block_size>>>(dwte, dwpe, dout, inp, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid encoder_backward2(float* dwte, float* dwpe,\n                    const float* dout, const int* inp,\n                    int B, int T, int C,\n                    const int block_size) {\n    const int grid_size = ceil_div(C, block_size);\n    encoder_backward_kernel2<<<grid_size, block_size>>>(dwte, dwpe, dout, inp, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid encoder_backward(int kernel_num,\n                     float* dwte, float* dwpe,\n                    const float* dout, const int* inp,\n                    int B, int T, int C,\n                    const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            encoder_backward1(dwte, dwpe, dout, inp, 
B, T, C, block_size);\n            break;\n        case 2:\n            encoder_backward2(dwte, dwpe, dout, inp, B, T, C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n    int V = 50257;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* dout = make_random_float(B * T * C);\n    int* inp = make_random_int(B * T, V);\n    float* dwte = make_zeros_float(V * C);\n    float* dwpe = make_zeros_float(T * C);\n\n    // move to GPU\n    float* d_dout;\n    int* d_inp;\n    float* d_dwte;\n    float* d_dwpe;\n    cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * sizeof(int)));\n    cudaCheck(cudaMalloc(&d_dwte, V * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dwpe, T * C * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_dout, dout, B * T * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * sizeof(int), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    encoder_backward_cpu(dwte, dwpe, dout, inp, B, T, C);\n\n    // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        cudaCheck(cudaMemset(d_dwte, 0, V * C * sizeof(float)));\n        cudaCheck(cudaMemset(d_dwpe, 0, T * C * sizeof(float)));\n        printf(\"Checking block size %d.\\n\", block_size);\n        
encoder_backward(kernel_num, d_dwte, d_dwpe, d_dout, d_inp, B, T, C, block_size);\n        validate_result(d_dwte, dwte, \"dwte\", V * C, 1e-5f);\n        validate_result(d_dwpe, dwpe, \"dwpe\", T * C, 1e-5f);\n    }\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        int repeat_times = 1000;\n        float elapsed_time = benchmark_kernel(repeat_times, encoder_backward,\n                                              kernel_num, d_dwte, d_dwpe, d_dout, d_inp, B, T, C, block_size);\n        printf(\"block_size %4d | time %.4f ms\\n\", block_size, elapsed_time);\n    }\n\n    // free memory\n    free(dout);\n    free(inp);\n    free(dwte);\n    free(dwpe);\n    cudaFree(d_dout);\n    cudaFree(d_inp);\n    cudaFree(d_dwte);\n    cudaFree(d_dwpe);\n\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/encoder_forward.cu",
    "content": "/*\nKernels for the positional encoder forward pass in GPT-2.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt encoder_forward.cu -o encoder_forward\n\nversion 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C\n./encoder_forward 1\n\nversion 2 is more optimized, parallelizes over all of B,T,C\n./encoder_forward 2\n\nversion 3 is like version 2 but uses float4 reads/writes\n./encoder_forward 3\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n#include <cassert>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\n// GPT-2 positional encoder forward pass\nvoid encoder_forward_cpu(float* out,\n                   const int* inp, const float* wte, const float* wpe,\n                   int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* out_bt = out + b * T * C + t * C;\n            int ix = inp[b * T + t];\n            const float* wte_ix = wte + ix * C;\n            const float* wpe_t = wpe + t * C;\n            for (int i = 0; i < C; i++) {\n                out_bt[i] = wte_ix[i] + wpe_t[i];\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// naive implementation into kernel, parallelize over B,T, loop over C\n__global__ void encoder_forward_kernel1(floatX* out,\n                               const int* inp, const floatX* wte, const floatX* wpe,\n                               int B, int T, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T;\n\n    if (idx < N) {\n        int b = idx / T;\n        int t = idx % T;\n        floatX* out_bt = out + b * T * C + t * C;\n        int ix = inp[b * T + t];\n        const floatX* wte_ix = wte + ix * C;\n        const floatX* wpe_t = wpe + t * C;\n        for (int 
i = 0; i < C; i++) {\n            out_bt[i] = (floatX)((float)wte_ix[i] + (float)wpe_t[i]);\n        }\n    }\n}\n\n// optimized implementation: parallelize over all of B,T,C\n__global__ void encoder_forward_kernel2(floatX* out,\n                               const int* inp, const floatX* wte, const floatX* wpe,\n                               int B, int T, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T * C;\n\n    if (idx < N) {\n        int bt = idx / C;\n        int b = bt / T;\n        int t = bt % T;\n        int c = idx % C;\n\n        int ix = inp[b * T + t];\n\n        floatX* out_btc = out + b * T * C + t * C + c;\n        const floatX* wte_ix = wte + ix * C + c;\n        const floatX* wpe_tc = wpe + t * C + c;\n        *out_btc = (floatX)((float)*wte_ix + (float)*wpe_tc);\n    }\n}\n\n__global__ void encoder_forward_kernel3(floatX* out,\n                               const int* inp, const floatX* wte, const floatX* wpe,\n                               int B, int T, int C) {\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n    int N = B * T * C;\n    if (idx < N) {\n        int bt = idx / C;\n        int b = bt / T;\n        int t = bt % T;\n        int c = idx % C;\n\n        int ix = inp[b * T + t];\n\n        floatX* out_btc = out + b * T * C + t * C + c;\n        const floatX* wte_ix = wte + ix * C + c;\n        const floatX* wpe_tc = wpe + t * C + c;\n\n        x128 packed_out;\n        x128 wte = load128cs(wte_ix);\n        x128 wpe = load128cs(wpe_tc);\n        #pragma unroll\n        for (int k = 0; k < wte.size; k++) {\n            packed_out[k] = (floatX)((float)wte[k] + (float)wpe[k]);\n        }\n        store128(out_btc, packed_out);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid encoder_forward1(floatX* out,\n                     const int* inp, const floatX* wte, const floatX* wpe,\n                     int B, 
int T, int C,\n                     const int block_size) {\n    const int N = B * T;\n    const int grid_size = ceil_div(N, block_size);\n    encoder_forward_kernel1<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid encoder_forward2(floatX* out,\n                     const int* inp, const floatX* wte, const floatX* wpe,\n                     int B, int T, int C,\n                     const int block_size) {\n    const int N = B * T * C;\n    const int grid_size = ceil_div(N, block_size);\n    encoder_forward_kernel2<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid encoder_forward3(floatX* out,\n                     const int* inp, const floatX* wte, const floatX* wpe,\n                     int B, int T, int C,\n                     const int block_size) {\n    const int N = B * T * C;\n    const int grid_size = ceil_div(N, (int)(block_size * x128::size));\n    encoder_forward_kernel3<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid encoder_forward(int kernel_num,\n                     floatX* out,\n                     const int* inp, const floatX* wte, const floatX* wpe,\n                     int B, int T, int C,\n                     const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            encoder_forward1(out, inp, wte, wpe, B, T, C, block_size);\n            break;\n        case 2:\n            encoder_forward2(out, inp, wte, wpe, B, T, C, block_size);\n            break;\n        case 3:\n            encoder_forward3(out, inp, wte, wpe, B, T, C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 
1024;\n    int C = 768;\n    int V = 50257;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n    int* inp = make_random_int(B * T, V);\n    float* wte = make_random_float(V * C);\n    float* wpe = make_random_float(T * C);\n\n    // move to GPU\n    floatX* d_out;\n    int* d_inp;\n    floatX* d_wte;\n    floatX* d_wpe;\n    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * sizeof(int)));\n    cudaCheck(cudaMalloc(&d_wte, V * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_wpe, T * C * sizeof(floatX)));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * sizeof(int), cudaMemcpyHostToDevice));\n    cudaCheck(memcpy_convert(d_wte, wte, V * C));\n    cudaCheck(memcpy_convert(d_wpe, wpe, T * C));\n\n    // read kernel_num from command line\n    int kernel_num = 2;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    encoder_forward_cpu(out, inp, wte, wpe, B, T, C);\n\n    // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        encoder_forward(kernel_num, d_out, d_inp, d_wte, d_wpe, B, T, C, block_size);\n#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)\n        float tol = 1e-5;\n#else\n        float tol = 1e-2f;\n#endif\n        validate_result(d_out, out, \"out\", B * T * C, tol);\n    }\n\n    printf(\"All results match. 
Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n        float elapsed_time = benchmark_kernel(repeat_times, encoder_forward,\n                                              kernel_num, d_out, d_inp, d_wte, d_wpe, B, T, C, block_size\n                                              );\n\n        // napkin math: estimate the memory bandwidth achieved\n        // for each (B,T,C) output element, we do 3 reads and 1 write, 4 bytes each\n        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s\n        long memory_ops = B * T * C * 4 * 4;\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s\\n\", block_size, elapsed_time, memory_bandwidth);\n    }\n\n    // free memory\n    free(out);\n    free(inp);\n    free(wte);\n    free(wpe);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_wte));\n    cudaCheck(cudaFree(d_wpe));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/fused_residual_forward.cu",
    "content": "/*\nKernels for residual forward pass fused with layernorm\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt fused_residual_forward.cu -o fused_residual_forward\n\nversion 1 is naive port from CPU code to kernel\n./fused_residual_forward 1\nversion 2 packs input into 128 bit memory reads\n./fused_residual_forward 2\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include \"assert.h\"\n#include <cuda_runtime.h>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference lol\n\nvoid residual_forward_cpu(float* out, const float* inp1, const float* inp2, int N) {\n    for (int i = 0; i < N; i++) {\n        out[i] = inp1[i] + inp2[i];\n    }\n}\n\nvoid layernorm_forward_cpu(float* out, float* mean, float* rstd,\n                           const float* inp, const float* weight, const float* bias,\n                           int B, int T, int C) {\n    float eps = 1e-5f;\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the input position inp[b,t,:]\n            const float* x = inp + b * T * C + t * C;\n            // calculate the mean\n            float m = 0.0f;\n            for (int i = 0; i < C; i++) {\n                m += x[i];\n            }\n            m = m/C;\n            // calculate the variance (without any bias correction)\n            float v = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v/C;\n            // calculate the rstd\n            float s = 1.0f / sqrtf(v + eps);\n            // seek to the output position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i < C; i++) {\n                float n = (s * (x[i] - m)); // normalized output\n                float o = n * weight[i] + bias[i]; // scale and shift it\n        
        out_bt[i] = o; // write\n            }\n            // cache the mean and rstd for the backward pass later\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// elementwise ops are nice and ez\n__global__ void residual_forward_kernel1(floatX* out, const floatX* inp1, const floatX* inp2, int N) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < N) {\n        out[idx] = (floatX)((float)inp1[idx] + (float)inp2[idx]);\n    }\n}\n\n// naive drag and drop implementation into kernel, parallelize over B,T, loop over C\n__global__ void layernorm_forward_kernel1(floatX* out, floatX* mean, floatX* rstd,\n                                          const floatX* inp, const floatX* weight, const floatX* bias,\n                                          int N, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    float eps = 1e-5f;\n\n    if (idx < N) {\n        // seek to the input position inp[idx,:]\n        const floatX* x = inp + idx * C;\n        // calculate the mean\n        float m = 0.0f;\n        for (int i = 0; i < C; i++) {\n            m += (float)x[i];\n        }\n        m = m / C;\n        // calculate the variance (without any bias correction)\n        float v = 0.0f;\n        for (int i = 0; i < C; i++) {\n            float xshift = (float)x[i] - m;\n            v += xshift * xshift;\n        }\n        v = v / C;\n        // calculate the rstd\n        float s = 1.0f / sqrtf(v + eps);\n        // seek to the output position in out[idx,:]\n        floatX* out_idx = out + idx * C;\n        for (int i = 0; i < C; i++) {\n            float n = (s * ((float)x[i] - m)); // normalized output\n            float o = n * (float)weight[i] + (float)bias[i]; // scale and shift it\n            out_idx[i] = o; // write\n        }\n        // cache the mean and rstd for the backward pass later\n     
   mean[idx] = m;\n        rstd[idx] = s;\n    }\n}\n\n// naive fusion; uncoalesced access pattern leads to terrible performance\n__global__ void fused_residual_forward2(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                                        const floatX* inp1, const floatX* inp2,\n                                        const floatX* weight, const floatX* bias,\n                                        int N, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if(idx > N) return;\n\n    // adjust pointers to current token\n    residual += C * idx;\n    normed += C * idx;\n    inp1 += C * idx;\n    inp2 += C * idx;\n\n    float eps = 1e-5f;\n\n    float m = 0.0f;\n    for(int c = 0; c < C; ++c) {\n        float out = (float)inp1[c] + (float)inp2[c];\n        m += out;\n        residual[c] = (floatX)out;\n    }\n\n    m = m / C;\n    float v = 0.0f;\n    for (int c = 0; c < C; c++) {\n        float xshift = (float)residual[c] - m;\n        v += xshift * xshift;\n    }\n    v = v / C;\n\n    // calculate the rstd\n    float s = 1.0f / sqrtf(v + eps);\n    for (int c = 0; c < C; c++) {\n        float n = (s * ((float)residual[c] - m)); // normalized output\n        float o = n * (float)weight[c] + (float)bias[c]; // scale and shift it\n        normed[c] = (floatX)o; // write\n    }\n    // cache the mean and rstd for the backward pass later\n    mean[idx] = (floatX)m;\n    rstd[idx] = (floatX)s;\n}\n\n// handle one token per warp for coalesced access\n__global__ void fused_residual_forward3(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                                        const floatX* inp1, const floatX* inp2,\n                                        const floatX* weight, const floatX* bias,\n                                        int N, int C) {\n    constexpr const int WarpSize = 32;\n    assert(blockDim.x == WarpSize);\n    int idx = blockIdx.x * blockDim.y + threadIdx.y;\n    if(idx > N) 
return;\n\n    // adjust pointers to current token\n    residual += C * idx;\n    normed += C * idx;\n    inp1 += C * idx;\n    inp2 += C * idx;\n\n    float eps = 1e-5f;\n    float m = 0.0f;\n    for(int c = threadIdx.x; c < C; c += WarpSize) {\n        float out = (float)inp1[c] + (float)inp2[c];\n        m += out;\n        residual[c] = out;\n    }\n\n    m = warpReduceSum(m);\n\n    m = m / C;\n    float v = 0.0f;\n    for(int c = threadIdx.x; c < C; c += WarpSize) {\n        float xshift = (float)residual[c] - m;\n        v += xshift * xshift;\n    }\n\n    v = warpReduceSum(v);\n    v = v / C;\n\n    // calculate the rstd\n    float s = 1.0f / sqrtf(v + eps);\n    for(int c = threadIdx.x; c < C; c += WarpSize) {\n        float n = (s * ((float)residual[c] - m)); // normalized output\n        float o = n * (float)weight[c] + (float)bias[c]; // scale and shift it\n        normed[c] = o; // write\n    }\n    // cache the mean and rstd for the backward pass later\n    if(threadIdx.x == 0) {\n        mean[idx] = m;\n        rstd[idx] = s;\n    }\n}\n\n// vectorized loading, single pass stats, streaming access and zigzag loop\n__global__ void fused_residual_forward_kernel4(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                                               const floatX* inp1, const floatX* inp2,\n                                               const floatX* weight, const floatX* bias,\n                                               int N, int C) {\n    using x128 = Packed128<floatX>;\n    constexpr const int WarpSize = 32;\n    assert(blockDim.x == WarpSize);\n    int idx = blockIdx.x * blockDim.y + threadIdx.y;\n    if(idx > N) return;\n\n    // adjust pointers to current token\n    residual += C * idx;\n    normed += C * idx;\n    inp1 += C * idx;\n    inp2 += C * idx;\n\n    const float eps = 1e-5f;\n    float sum = 0.0f;\n    float sum_sq = 0.0f;\n    int c = threadIdx.x * x128::size;\n    for(; c < C; c += WarpSize * x128::size) {\n     
   const x128 in1 = load128cs(inp1 + c);\n        const x128 in2 = load128cs(inp2 + c);\n        x128 out;\n        for(int k = 0; k < x128::size; ++k) {\n            out[k] = (floatX)((float)in1[k] + (float)in2[k]);\n            sum += (float)out[k];\n            sum_sq += (float)out[k] * (float)out[k];\n        }\n        store128(residual + c, out);\n    }\n\n    sum = warpReduceSum(sum);\n    sum_sq = warpReduceSum(sum_sq);\n\n    float m = sum / C;\n    float v = sum_sq / C - m * m;\n    float s = rsqrtf(v + eps);\n\n    c -= WarpSize * x128::size;\n    for(; c >= 0; c -= WarpSize * x128::size) {\n        const x128 res = load128cs(residual + c);\n        const x128 w = load128(weight + c);\n        const x128 b = load128(bias + c);\n        x128 out;\n        for(int k = 0; k < x128::size; ++k) {\n            float n = s * ((float)res[k] - m); // normalized output\n            float o = n * (float)w[k] + (float)b[k]; // scale and shift it\n            out[k] = o;\n        }\n\n        store128cs(normed + c, out);\n    }\n    // cache the mean and rstd for the backward pass later\n    if(threadIdx.x == 0) {\n        mean[idx] = m;\n        rstd[idx] = s;\n    }\n}\n\n// what do you want in shared memory? 
EVERYTHING!\n// thus, we no longer require zigzag loops and can do the numerically more stable variance estimation\n// needs special attention in the kernel launcher to ensure we have enough smem.\n__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                                               const floatX* inp1, const floatX* inp2,\n                                               const floatX* weight, const floatX* bias,\n                                               int N, int C) {\n    constexpr const int WarpSize = 32;\n    assert(blockDim.x == WarpSize);\n\n    // load weights and biases into shared memory\n    // do this before we allow any threads to exit!\n    extern __shared__ char params[];\n    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so\n    // let's keep everything as x128\n    x128* s_weight = reinterpret_cast<x128*>(params);\n    x128* s_bias = reinterpret_cast<x128*>(params) + (C / x128::size);\n    x128* s_res = reinterpret_cast<x128*>(params) + ((2 + threadIdx.y) * C / x128::size);\n\n    int sidx = (threadIdx.x + WarpSize * threadIdx.y) * x128::size;\n    for(int i = sidx; i < C; i += blockDim.y * WarpSize * x128::size) {\n        s_weight[i/x128::size] = load128(weight + i);\n        s_bias[i/x128::size] = load128(bias + i);\n    }\n    __syncthreads();\n\n    int idx = blockIdx.x * blockDim.y + threadIdx.y;\n    if(idx > N) return;\n\n    // adjust pointers to current token\n    residual += C * idx;\n    normed += C * idx;\n    inp1 += C * idx;\n    inp2 += C * idx;\n\n    const float eps = 1e-5f;\n    float sum = 0.0f;\n    for(int c = threadIdx.x * x128::size; c < C; c += WarpSize * x128::size) {\n        const x128 in1 = load128cs(inp1 + c);\n        const x128 in2 = load128cs(inp2 + c);\n        x128 out;\n        for(int k = 0; k < x128::size; ++k) {\n            out[k] = (floatX)((float)in1[k] + (float)in2[k]);\n            
sum += (float)out[k];\n        }\n        store128cs(residual + c, out);\n        s_res[c / x128::size] = out;\n    }\n\n    sum = warpReduceSum(sum);\n    float m = sum / C;\n    float v = 0.f;\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WarpSize * x128::size) {\n        const x128 res = s_res[c / x128::size];\n        for(int k = 0; k < x128::size; ++k) {\n            v += ((float)res[k] - m) * ((float)res[k] - m);\n        }\n    }\n\n    v = warpReduceSum(v) / C;\n    float s = rsqrtf(v + eps);\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WarpSize * x128::size) {\n        const x128 res = s_res[c / x128::size];\n        const x128 w = s_weight[c / x128::size];\n        const x128 b = s_bias[c / x128::size];\n        x128 out;\n        for(int k = 0; k < x128::size; ++k) {\n            float n = s * ((float)res[k] - m); // normalized output\n            float o = n * (float)w[k] + (float)b[k]; // scale and shift it\n            out[k] = o;\n        }\n\n        store128cs(normed + c, out);\n    }\n    // cache the mean and rstd for the backward pass later\n    if(threadIdx.x == 0) {\n        mean[idx] = m;\n        rstd[idx] = s;\n    }\n}\n\n\n// using multiple warps per token, and keep threads persistent, so we never have to reload weights and biases\n// if we had one warp per token, though, this would require us to use a huge amount of shared memory. Therefore,\n// we use multiple warps per token; but generally we cannot use the entire block, because that would give too\n// little work per warp to be effective (each warp processes 256 bfloat16 elements, so for C=768 more than 3 warps\n// will just mean idle). 
Therefore, we add a z dimension, where warps with different z handle different tokens.\n// all this makes the launcher logic more complicated :(\n__global__ void fused_residual_forward_kernel6(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                                               const floatX* inp1, const floatX* inp2,\n                                               const floatX* weight, const floatX* bias,\n                                               int N, int C) {\n    constexpr const int WarpSize = 32;\n    assert(blockDim.x == WarpSize);\n\n    // load weights and biases into shared memory\n    // do this before we allow any threads to exit!\n    extern __shared__ char params[];\n    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so\n    // let's keep everything as x128\n    // weights and biases are  shared among all tokens\n    x128* s_weight = reinterpret_cast<x128*>(params);\n    x128* s_bias = reinterpret_cast<x128*>(params + C * sizeof(floatX));\n    // residual output (input to layernorm) is independent for each sub-block indicates by threadIdx.z\n    x128* s_res = reinterpret_cast<x128*>(params + (2 + threadIdx.z) * C * sizeof(floatX));\n    // similarly, each sub-block needs its own reduction buffers\n    float* s_mean = reinterpret_cast<float*>(params + (2 + blockDim.z) * C * sizeof(floatX) + threadIdx.z * 32 * sizeof(float));\n    float* s_var = reinterpret_cast<float*>(params + (2 + blockDim.z) * C * sizeof(floatX) + 32 * sizeof(float) * (blockDim.z + threadIdx.z));\n\n    int cidx = (threadIdx.x + WarpSize * threadIdx.y) * x128::size;\n    int step = blockDim.y * WarpSize * x128::size;\n\n    for(int c = cidx; c < C; c += step) {\n        s_weight[c / x128::size] = load128(weight + c);\n        s_bias[c / x128::size] = load128(bias + c);\n    }\n\n    // the block-level reductions will cause sync before the first time we read these\n    // => no syncthreads needed here\n\n   
 // loop over all tokens\n    for(int tidx = blockIdx.x * blockDim.z + threadIdx.z; tidx < N; tidx += gridDim.x * blockDim.z) {\n        // adjust pointers to current token\n        floatX* residual_bt = residual + C * tidx;\n        floatX* normed_bt = normed + C * tidx;\n        const floatX* inp1_bt = inp1 + C * tidx;\n        const floatX* inp2_bt = inp2 + C * tidx;\n\n        const float eps = 1e-5f;\n        float sum = 0.0f;\n        for (int c = cidx; c < C; c += step) {\n            const x128 in1 = load128cs(inp1_bt + c);\n            const x128 in2 = load128cs(inp2_bt + c);\n            x128 out;\n            for (int k = 0; k < x128::size; ++k) {\n                out[k] = (float) in1[k] + (float) in2[k];\n                sum += (float) out[k];\n            }\n            store128cs(residual_bt + c, out);\n            s_res[c / x128::size] = out;\n        }\n        sum = warpReduceSum(sum);\n        if(threadIdx.x == 0) {\n            s_mean[threadIdx.y] = sum;\n        }\n        __syncthreads();\n        float m = warpReduceSum(threadIdx.x < blockDim.y ? s_mean[threadIdx.x] : 0.f) / C;\n        // normally, we'd syncthread here to make sure that no warp is already at the next\n        // iteration of the loop, messing with s_mean. The fact that we interleave s_mean and s_var means\n        // we don't need these additional syncs.\n        float v = 0.f;\n\n        for (int c = cidx; c < C; c += step) {\n            const x128 res = s_res[c / x128::size];\n            for (int k = 0; k < x128::size; ++k) {\n                v += ((float) res[k] - m) * ((float) res[k] - m);\n            }\n        }\n\n        v = warpReduceSum(v);\n        if(threadIdx.x == 0) {\n            s_var[threadIdx.y] = v;\n        }\n        __syncthreads();\n        v = warpReduceSum(threadIdx.x < blockDim.y ? 
s_var[threadIdx.x] : 0.f) / C;\n        float s = rsqrtf(v + eps);\n\n        for (int c = cidx; c < C; c += step) {\n            const x128 res = s_res[c / x128::size];\n            const x128 w = s_weight[c / x128::size];\n            const x128 b = s_bias[c / x128::size];\n            x128 out;\n            for (int k = 0; k < x128::size; ++k) {\n                float n = s * ((float) res[k] - m); // normalized output\n                float o = n * (float) w[k] + (float) b[k]; // scale and shift it\n                out[k] = o;\n            }\n\n            store128(normed_bt + c, out);\n        }\n        // cache the mean and rstd for the backward pass later\n        if (threadIdx.x == 0 && threadIdx.y == 0) {\n            mean[tidx] = m;\n            rstd[tidx] = s;\n        }\n    }\n}\n\n\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid fused_residual_forward1(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, const int block_size) {\n    const int grid_size_resid = ceil_div(N * C, block_size);\n    residual_forward_kernel1<<<grid_size_resid, block_size>>>(residual, inp1, inp2, N*C);\n    cudaCheck(cudaGetLastError());\n    const int grid_size_ln = ceil_div(N, block_size);\n    layernorm_forward_kernel1<<<grid_size_ln, block_size>>>(normed, mean, rstd, residual, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_residual_forward2(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, const int block_size) {\n    const int grid_size = ceil_div(N, (int)(block_size));\n    
fused_residual_forward2<<<grid_size, block_size>>>(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_residual_forward3(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, const int block_size) {\n    int block_y = block_size / 32;\n    const int grid_size = ceil_div(N, block_y);\n    fused_residual_forward3<<<grid_size, dim3(32, block_y)>>>(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_residual_forward4(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, const int block_size) {\n    int block_y = block_size / 32;\n    const int grid_size = ceil_div(N, block_y);\n    fused_residual_forward_kernel4<<<grid_size, dim3(32, block_y)>>>(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_residual_forward5(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, const int block_size) {\n    int block_y = block_size / 32;\n    const int grid_size = ceil_div(N, block_y);\n    size_t smem = (2 + block_y) * C * sizeof(floatX);\n\n    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute\n    // this may fail, in which case we fall back to the smem free implementation.\n    cudaCheck(cudaGetLastError());\n    auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem);\n    cudaGetLastError();\n    if(status == cudaSuccess) {\n        fused_residual_forward_kernel5<<<grid_size, dim3(32, block_y), smem>>>(residual, normed, mean, rstd, inp1, inp2,\n                                                                               weight, bias, N, C);\n    } else {\n        fused_residual_forward_kernel4<<<grid_size, dim3(32, block_y)>>>(residual, normed, mean, rstd, inp1, inp2,\n                                                                         weight, bias, N, C);\n    }\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_residual_forward6(floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, const int block_size) {\n    int warps_per_token = max(1, C / Packed128<floatX>::size / 32);\n    int total_warps = block_size / 32;\n    int block_z = max(1, total_warps / warps_per_token);\n    int block_y = max(1, total_warps / block_z);\n    size_t smem = (2 + block_z) * C * sizeof(floatX) + 64 * sizeof(float) * block_z;\n\n    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute\n    // this may fail, in which case we fall back to the smem free implementation.\n    cudaCheck(cudaGetLastError());\n    auto status = cudaFuncSetAttribute(fused_residual_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);\n    cudaGetLastError();\n    if(status == cudaSuccess) {\n        const int num_blocks = max(1, cuda_threads_per_SM * cuda_num_SMs / block_size);\n        fused_residual_forward_kernel6<<<num_blocks, dim3(32, block_y, block_z), smem>>>(residual, normed, mean, rstd, inp1, inp2,\n                                                                               weight, bias, N, C);\n    } else {\n        const int grid_size = ceil_div(N, total_warps);\n        
fused_residual_forward_kernel4<<<grid_size, dim3(32, total_warps)>>>(residual, normed, mean, rstd, inp1, inp2,\n                                                                         weight, bias, N, C);\n    }\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid fused_residual_forward(int kernel_num, floatX* residual, floatX* normed, floatX* mean, floatX* rstd,\n                            const floatX* inp1, const floatX* inp2,\n                            const floatX* weight, const floatX* bias,\n                            int N, int C, const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            fused_residual_forward1(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);\n            break;\n        case 2:\n            fused_residual_forward2(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);\n            break;\n        case 3:\n            fused_residual_forward3(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);\n            break;\n        case 4:\n            fused_residual_forward4(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);\n            break;\n        case 5:\n            fused_residual_forward5(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);\n            break;\n        case 6:\n            fused_residual_forward6(residual, normed, mean, rstd, inp1, inp2, weight, bias, N, C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, const char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", 
kernel_num);\n\n    // create host memory of random numbers\n    float* residual = (float*)malloc(B * T * C * sizeof(float));\n    float* normed = (float*)malloc(B * T * C * sizeof(float));\n    float* inp1 = make_random_float(B * T * C);\n    float* inp2 = make_random_float(B * T * C);\n    float* mean = (float*)malloc(B * T * sizeof(float));\n    float* rstd = (float*)malloc(B * T * sizeof(float));\n    float* weight = make_random_float(C);\n    float* bias = make_random_float(C);\n    \n    // move to GPU\n    floatX* d_residual;\n    floatX* d_normed;\n    floatX* d_inp1;\n    floatX* d_inp2;\n    floatX* d_mean;\n    floatX* d_rstd;\n    floatX* d_weight;\n    floatX* d_bias;\n    cudaCheck(cudaMalloc(&d_residual, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_normed, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_weight, C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_bias, C * sizeof(float)));\n    cudaCheck(memcpy_convert(d_inp1, inp1, B * T * C));\n    cudaCheck(memcpy_convert(d_inp2, inp2, B * T * C));\n    cudaCheck(memcpy_convert(d_weight, weight, C));\n    cudaCheck(memcpy_convert(d_bias, bias, C));\n\n    // first check the correctness of the kernel\n    residual_forward_cpu(residual, inp1, inp2, B * T * C);\n    layernorm_forward_cpu(normed, mean, rstd, residual, weight, bias, B, T, C);\n\n    // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        cudaCheck(cudaMemset(d_residual, 0, B * T * C * sizeof(floatX)));\n        fused_residual_forward(kernel_num, 
d_residual, d_normed, d_mean, d_rstd, d_inp1, d_inp2, d_weight, d_bias,\n                               B*T, C, block_size);\n        float tol = std::is_same_v<floatX, float> ? 1e-5 : 5e-2;\n        validate_result(d_residual, residual, \"residual\", B * T * C, tol);\n        validate_result(d_mean, mean, \"mean\", B * T, tol);\n        validate_result(d_rstd, rstd, \"rstd\", B * T, tol);\n        validate_result(d_normed, normed, \"normed\", B * T * C, tol);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n        float elapsed_time = benchmark_kernel(repeat_times, fused_residual_forward, kernel_num,\n                                              d_residual, d_normed, d_mean, d_rstd, d_inp1, d_inp2, d_weight, d_bias,\n                                              B*T, C, block_size\n                                              );\n\n        // napkin math: estimate the memory bandwidth achieved\n        // for each (B,T,C) output element, we do 2 reads and 2 writes, plus 2 BT writes for mean/rstd\n        // and e.g. 
A100 40GB PCIe is advertised at 1,555GB/s\n        long memory_ops = B * T * (C * 4 + 2) * sizeof(floatX);\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n        float toks_per_msec = B * T / elapsed_time / 1e3;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s | elements: %.2f ktok/ms\\n\",\n               block_size, elapsed_time, memory_bandwidth, toks_per_msec);\n    }\n\n    // free memory\n    free(residual);\n    free(normed);\n    free(mean);\n    free(rstd);\n    free(weight);\n    free(bias);\n    free(inp1);\n    free(inp2);\n    cudaCheck(cudaFree(d_residual));\n    cudaCheck(cudaFree(d_normed));\n    cudaCheck(cudaFree(d_mean));\n    cudaCheck(cudaFree(d_rstd));\n    cudaCheck(cudaFree(d_weight));\n    cudaCheck(cudaFree(d_bias));\n    cudaCheck(cudaFree(d_inp1));\n    cudaCheck(cudaFree(d_inp2));\n\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/gelu_backward.cu",
    "content": "/*\nKernels for gelu backward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt gelu_backward.cu -o gelu_backward\n\nIf encountering \"error: identifier \"M_PI\" is undefined\", add the following lines to the top of the file:\n\n#define _USE_MATH_DEFINES\n#include <math.h>  OR  #include <cmath>\n\nversion 1 is naive port from CPU code to kernel\n./gelu_backward 1\n\nversion 2 uses the Packed128 data structure\n./gelu_backward 2\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\n#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)\n\nvoid gelu_backward_cpu(float* dinp, const float* inp, const float* dout, const int N) {\n    for (int i = 0; i < N; i++) {\n        float x = inp[i];\n        float cube = 0.044715f * x * x * x;\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        float tanh_out = tanhf(tanh_arg);\n        float coshf_out = coshf(tanh_arg);\n        float sech_out = 1.0f / (coshf_out * coshf_out);\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        dinp[i] = (floatX)(local_grad * (float)dout[i]);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// elementwise ops are nice and ez\n__global__ void gelu_backward1(floatX* dinp, const floatX* inp, const floatX* dout, int N) {\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < N) {\n        float x = (float)inp[i];\n        float cube = 0.044715f * x * x * x;\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        float tanh_out = tanhf(tanh_arg);\n        float coshf_out = coshf(tanh_arg);\n        float sech_out = 1.0f / (coshf_out * coshf_out);\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f 
* sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        dinp[i] = (floatX)(local_grad * (float)dout[i]);\n    }\n}\n\n__global__ void gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {\n    int i = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n    if (i < N) {\n        x128 packed_dinp;\n        x128 packed_inp = load128cs(inp + i);\n        x128 packed_dout = load128cs(dout + i);\n        for (int k = 0; k < packed_inp.size; ++k) {\n            float x = (float)packed_inp[k];\n            float cube = 0.044715f * x * x * x;\n            float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n            float tanh_out = tanhf(tanh_arg);\n            float coshf_out = coshf(tanh_arg);\n            float sech_out = 1.0f / (coshf_out * coshf_out);\n            float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n            packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);\n        }\n\n        store128(dinp + i, packed_dinp);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid gelu_backward1(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {\n    const int grid_size = ceil_div(N, block_size);\n    gelu_backward1<<<grid_size, block_size>>>(dinp, inp, dout, N);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid gelu_backward2(floatX* dinp, const floatX* inp, const floatX* dout, int N, const int block_size) {\n    const int grid_size = ceil_div(N, block_size * x128::size);\n    gelu_backward2<<<grid_size, block_size>>>(dinp, inp, dout, N);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid gelu_backward(int kernel_num,\n                  floatX* dinp, \n                  const floatX* inp, \n                  const floatX* dout,\n                  int B, int T, int C,\n                  int block_size) {\n 
   switch (kernel_num) {\n        case 1:\n            gelu_backward1(dinp, inp, dout, B * T * C, block_size);\n            break;\n        case 2:\n            gelu_backward2(dinp, inp, dout, B * T * C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n\n    // create host memory of random numbers\n    float* dinp = (float*)malloc(B * T * C * sizeof(float));\n    float* inp = make_random_float(B * T * C);\n    float* dout = make_random_float(B * T * C);\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    gelu_backward_cpu(dinp, inp, dout, B * T * C);\n\n    // move to GPU\n    floatX* d_dinp;\n    floatX* d_inp;\n    floatX* d_dout;\n    cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(floatX)));\n\n    cudaCheck(memcpy_convert(d_inp, inp, B * T * C));\n    cudaCheck(memcpy_convert(d_dout, dout, B * T * C));\n\n    // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        gelu_backward(kernel_num, d_dinp, d_inp, d_dout, B, T, C, block_size);\n#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)\n        float tol = 1e-5;\n#else\n        float tol = 1e-2f;\n#endif\n        validate_result(d_dinp, dinp, \"dinp\", B * T * C, tol);\n    }\n\n    printf(\"All results 
match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n\n        float elapsed_time = benchmark_kernel(repeat_times, gelu_backward,\n                                              kernel_num, d_dinp, d_inp, d_dout,\n                                              B, T, C, block_size);\n\n        // napkin math: estimate the memory bandwidth achieved\n        // for each (B,T,C) output element, we do 1 read and 1 write, 4 bytes each\n        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s\n        long memory_ops = B * T * C * 2 * 4;\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s\\n\", block_size, elapsed_time, memory_bandwidth);\n    }\n\n    // free memory\n    free(dinp);\n    free(inp);\n    free(dout);\n    cudaCheck(cudaFree(d_dinp));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_dout));\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/gelu_forward.cu",
    "content": "/*\nKernels for gelu forward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt gelu_forward.cu -o gelu_forward\n\nIf encountering \"error: identifier \"M_PI\" is undefined\", add the following lines to the top of the file:\n\n#define _USE_MATH_DEFINES\n#include <math.h>  OR  #include <cmath>\n\nversion 1 is naive CPU port\n./gelu_forward 1\n\nversion 2 is bfloat16 with the Packed128 data structure\n./gelu_forward 2\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\n#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)\n\nvoid gelu_forward_cpu(float* out, const float* inp, int N) {\n    for (int i = 0; i < N; i++) {\n        float x = inp[i];\n        float cube = 0.044715f * x * x * x;\n        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// elementwise ops are nice and ez\n__global__ void gelu_forward_kernel1(floatX* out, const floatX* inp, int N) {\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < N) {\n        float xi = inp[i];\n        float cube = 0.044715f * xi * xi * xi;\n        out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));\n    }\n}\n\n// elementwise ops are nice and ez\n__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp, int N) {\n    int i = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n    if (i < N) {\n        x128 packed_out;\n        x128 packed_inp = load128cs(inp + i); // load and do not keep in cache\n        for(int k = 0; k < packed_inp.size; ++k) {\n            float xi = (float)packed_inp[k];\n            float cube = 0.044715f * xi * xi * xi;\n            packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + 
cube))));\n        }\n        // store instead of storecs (without cache streaming) in case it is useful for the\n        // data to be in the cache for the next operation after this GeLU\n        store128(out + i, packed_out);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid gelu_forward1(floatX* out, const floatX* inp, int N, const int block_size) {\n    const int grid_size = ceil_div(N, block_size);\n    gelu_forward_kernel1<<<grid_size, block_size>>>(out, inp, N);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid gelu_forward2(floatX* out, const floatX* inp, int N, const int block_size) {\n    const int grid_size = ceil_div(N, block_size * x128::size);\n    gelu_forward_kernel2<<<grid_size, block_size>>>(out, inp, N);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid gelu_forward(int kernel_num,\n                  floatX* out,\n                  const floatX* inp,\n                  int B, int T, int C,\n                  int block_size) {\n    switch (kernel_num) {\n        case 1:\n            gelu_forward1(out, inp, B * T * C, block_size);\n            break;\n        case 2:\n            gelu_forward2(out, inp, B * T * C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, const char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n    float* inp = make_random_float(B * T * C);\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    gelu_forward_cpu(out, inp, 
B * T * C);\n\n    // move to GPU\n    floatX* d_out;\n    floatX* d_inp;\n    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX)));\n    cudaCheck(memcpy_convert(d_inp, inp, B * T * C));\n\n    // time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        gelu_forward(kernel_num, d_out, d_inp, B, T, C, block_size);\n#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)\n        float tol = 1e-5;\n#else\n        float tol = 1e-2f;\n#endif\n        validate_result(d_out, out, \"out\", B * T * C, tol);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n\n        float elapsed_time = benchmark_kernel(repeat_times, gelu_forward,\n                                              kernel_num, d_out, d_inp,\n                                              B, T, C, block_size);\n\n        // napkin math: estimate the memory bandwidth achieved\n        // for each (B,T,C) output element, we do 1 read and 1 write, 4 bytes each\n        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s\n        long memory_ops = B * T * C * 2 * (int)sizeof(floatX);\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s\\n\", block_size, elapsed_time, memory_bandwidth);\n    }\n\n    // free memory\n    free(out);\n    free(inp);\n\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp));\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/global_norm.cu",
    "content": "/*\nKernels for a global norm.\nGlobal norm in this context means that we want to calculate a single norm cooperatively using all avalailable SMs, instead\n of multiple norms that can be handled by separate blocks.\n\nCompile example:\nnvcc -O3 --use_fast_math global_norm.cu -o global_norm\n*/\n\n\n#include <assert.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\n// turn on bf16 as default, done up here for now\n#define ENABLE_BF16\n#include \"common.h\"\n\ncudaDeviceProp deviceProp;\n\nfloat global_norm_cpu(const float* data, size_t count) {\n    // accumulate in double so we have an accurate numerical reference\n    double acc = 0.0;\n    for(size_t i = 0; i < count; ++i) {\n        acc  += (double)data[i] * (double)data[i];\n    }\n    return (float)acc;\n}\n\n\ntemplate<class T>\n__global__ void norm_kernel1(float* out, const T* data, size_t count) {\n    // we want as few atomics as possible, so each block tries to do\n    // the maximum amount of work (so no fixed chunk, but instead iterating\n    // until we run out of data), and then we reduce inside the block\n    // and finally have just one atomic per block.\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n\n    __shared__ float block_result[32];\n\n    // out will be updated atomically from all thread blocks\n    size_t index = threadIdx.x + blockDim.x * blockIdx.x;\n    size_t grid_width = blockDim.x * gridDim.x;\n    float accumulator = 0.f;\n    for(size_t i = index; i < count; i += grid_width) {\n        accumulator += (float)data[i] * (float)data[i];\n    }\n    // warp-level reduce\n    float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});\n    block_result[warp.meta_group_rank()] = warp_result;\n    block.sync();\n    if(warp.meta_group_rank() == 0) {\n        float gather = warp.thread_rank() < warp.meta_group_size() ? 
block_result[warp.thread_rank()] : 0.f;\n        float block_sum = cg::reduce(warp, gather, cg::plus<float>{});\n        if(warp.thread_rank() ==  0) {\n            atomicAdd(out, block_sum);\n        }\n    }\n}\n\ntemplate<class T>\n__global__ void norm_kernel2(float* out, const T* data, size_t count) {\n    // concrete example for an A100 GPU (108 SMs, 2048 max threads each)\n    // so there are 2048 * 108 = 221,184 threads total\n    // say the block_size is 512, then we would launch 432 blocks in total\n    // say num_params is ~100M, each thread will process ~500 elements\n    // warps reduce with warp-level reduce, we have 221,184/32 = 6,912 warps\n    // and then each warp atomicAdd's to global memory, total of 6,912 atomics\n\n    // no shared memory; but one atomic per warp instead of per block\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n\n    // out will be updated atomically from all thread blocks\n    size_t index = threadIdx.x + blockDim.x * blockIdx.x;\n    size_t grid_width = blockDim.x * gridDim.x;\n    float accumulator = 0.f;\n    for(size_t i = index; i < count; i += grid_width) {\n        accumulator += (float)data[i] * (float)data[i];\n    }\n\n    // warp-level reduce\n    float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});\n    // and atomic in global buffer\n    if(warp.thread_rank() == 0) {\n        atomicAdd(out, warp_result);\n    }\n}\n\ntemplate<class T>\n__global__ void norm_kernel3(float* out, const T* data, size_t count) {\n    size_t index = blockIdx.x * blockDim.x + threadIdx.x;\n    size_t grid_width = blockDim.x * gridDim.x;\n    float accumulator = 0.f;\n    for(size_t i = index; i < count; i += grid_width) {\n        accumulator += (float)data[i] * (float)data[i];\n    }\n    // block-level reduce\n    float block_sum = blockReduce<warpReduceSum>(accumulator);\n    if(threadIdx.x == 0) {\n 
       atomicAdd(out, block_sum);\n    }\n}\n\n// Same as kernel3 but without atomic adds -> this makes the result deterministic, because atomic adds\n// complete in an arbitrary order and floating point addition is not associative. Roughly same performance as kernel3.\ntemplate<class T>\n__global__ void norm_kernel4(float* out, const T* data, size_t count) {\n    size_t index = blockIdx.x * blockDim.x + threadIdx.x;\n    size_t grid_width = blockDim.x * gridDim.x;\n    float accumulator = 0.f;\n    for(size_t i = index; i < count; i += grid_width) {\n        accumulator += (float)data[i] * (float)data[i];\n    }\n    // block-level reduce\n    float block_sum = blockReduce<warpReduceSum>(accumulator);\n    // each block accumulates its partial sum to out[blockIdx.x]\n    // we want to avoid using atomic add here so we combine this kernel with the aggregate kernel call\n    // that sums up the partial block sums\n    if(threadIdx.x == 0) {\n        out[blockIdx.x] = block_sum;\n    }\n}\n\n__global__ void global_norm_aggregate_kernel(float* out, size_t count) {\n    size_t index = threadIdx.x;\n    // grab block sums from the previous kernel, use 0. as the neutral sum element\n    float block_sum = (index < count) ? out[index] : 0.f;\n    float sum = blockReduce<warpReduceSum>(block_sum);\n    if(threadIdx.x == 0) {\n        out[0] = sum;  // out[0] ends up with the final norm squared\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\ntemplate<typename T>\nvoid global_norm1(float* out, const T* values, size_t count, int block_size) {\n    // launch just enough blocks to fill the grid. deliberately no DIV_CEIL.\n    // having one block less than possible is a tiny performance hit, having\n    // one block too many is catastrophic, since it only can start once all the other\n    // blocks finish. 
anyway, I think cuda_threads_per_SM should be a multiple of 512\n    // on all gpus, so the division really is going to be exact.\n    const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;\n    assert(grid_size > 0);      // gives a better error than letting the call below fail\n    norm_kernel1<<<grid_size, block_size>>>(out, values, count);\n    cudaCheck(cudaGetLastError());\n}\n\ntemplate<typename T>\nvoid global_norm2(float* out, const T* values, size_t count, int block_size) {\n    // ditto\n    const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;\n    assert(grid_size > 0);      // gives a better error than letting the call below fail\n    norm_kernel2<<<grid_size, block_size>>>(out, values, count);\n    cudaCheck(cudaGetLastError());\n}\n\ntemplate<typename T>\nvoid global_norm3(float* out, const T* values, size_t count, int block_size) {\n    // launch just enough blocks to fill the grid. deliberately no DIV_CEIL.\n    // having one block less than possible is a tiny performance hit, having\n    // one block too many is catastrophic, since it only can start once all the other\n    // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512\n    // on all gpus, so the division really is going to be exact.\n    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;\n    assert(grid_size > 0);  // gives a better error than letting the call below fail\n    norm_kernel3<<<grid_size, block_size>>>(out, values, count);\n    cudaCheck(cudaGetLastError());\n}\n\ntemplate<typename T>\nvoid global_norm4(float* out, const T* values, size_t count, int block_size) {\n    if (block_size <= 64) {\n        block_size = 128;  // to avoid triggering the assert below\n    }\n    // launch just enough blocks to fill the grid. 
deliberately no DIV_CEIL.\n    // having one block less than possible is a tiny performance hit, having\n    // one block too many is catastrophic, since it only can start once all the other\n    // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512\n    // on all gpus, so the division really is going to be exact.\n    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;\n    assert(grid_size > 0);      // gives a better error than letting the call below fail\n    assert(grid_size < 1024);  // we want to later accumulate the block sums in a single block\n    norm_kernel4<<<grid_size, block_size>>>(out, values, count);\n    cudaCheck(cudaGetLastError());\n    global_norm_aggregate_kernel<<<1, 1024>>>(out, grid_size);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid global_norm(int kernel_num, float* out, const floatX* values, size_t count, int block_size) {\n    switch (kernel_num) {\n        case 1:\n            return global_norm1(out, values, count, block_size);\n        case 2:\n            return global_norm2(out, values, count, block_size);\n        case 3:\n            return global_norm3(out, values, count, block_size);\n        case 4:\n            return global_norm4(out, values, count, block_size);\n    }\n}\n\nint main(int argc, const char **argv) {\n    setup_main();\n    cudaGetDeviceProperties(&deviceProp, 0);\n\n    int C = 768;\n    int L = 12;\n\n    size_t num_params = (size_t)(C * 4*C + C*C) * 2 * L;\n\n    // create host memory of random numbers\n    float* inp = make_random_float(num_params);\n    // scale them down\n    for(size_t i = 0; i < num_params; ++i) {\n        inp[i] *= 1e-3;\n    }\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    float out = global_norm_cpu(inp, 
num_params);\n\n    // move to GPU\n    float* d_out;\n    floatX* d_inp;\n    cudaCheck(cudaMalloc(&d_out,  1024 * sizeof(float)));  // 1024 needed for kernel 4\n    cudaCheck(cudaMalloc(&d_inp, num_params * sizeof(floatX)));\n    cudaCheck(memcpy_convert(d_inp, inp, num_params));\n\n    int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        cudaCheck(cudaMemset(d_out, 0, sizeof(float)));\n        global_norm(kernel_num, d_out, d_inp, num_params, block_size);\n        validate_result(d_out, &out, \"out\", 1, 1e-2f);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n\n        float elapsed_time = benchmark_kernel(repeat_times, global_norm,\n                                              kernel_num, d_out, d_inp,\n                                              num_params, block_size);\n        size_t memory_ops = num_params * sizeof(floatX);\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s\\n\", block_size, elapsed_time, memory_bandwidth);\n    }\n\n    // free memory\n    free(inp);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp));\n}"
  },
  {
    "path": "dev/cuda/layernorm_backward.cu",
    "content": "/*\nKernels for layernorm backward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_backward.cu -o layernorm_backward\n\nversion 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C\n./layernorm_backward 1\n\nversion 2 moves a lot of reduction to shared memory over global memory\n./layernorm_backward 2\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n#include <assert.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid layernorm_forward_cpu(float* out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C) {\n    // reference: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html\n    // both inp and out are (B,T,C) of the activations\n    // mean and rstd are (B,T) buffers, to be used later in backward pass\n    // at each position (b,t) of the input, the C-dimensional vector\n    // of activations gets normalized, then scaled and shifted\n    float eps = 1e-5f;\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the input position inp[b,t,:]\n            const float* x = inp + b * T * C + t * C;\n            // calculate the mean\n            float m = 0.0f;\n            for (int i = 0; i < C; i++) {\n                m += x[i];\n            }\n            m = m/C;\n            // calculate the variance (without any bias correction)\n            float v = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v/C;\n            // calculate the rstd (reciprocal standard deviation)\n            float s = 1.0f / sqrtf(v + eps);\n    
        // seek to the output position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i < C; i++) {\n                float n = (s * (x[i] - m)); // normalize\n                float o = n * weight[i] + bias[i]; // scale and shift\n                out_bt[i] = o; // write\n            }\n            // cache the mean and rstd for the backward pass later\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n\nvoid layernorm_backward_cpu(float* dinp, float* dweight, float* dbias,\n                        const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,\n                        int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            const float* dout_bt = dout + b * T * C + t * C;\n            const float* inp_bt = inp + b * T * C + t * C;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            const float mean_bt = mean[b * T + t];\n            const float rstd_bt = rstd[b * T + t];\n\n            // first: two reduce operations\n            float dnorm_mean = 0.0f;\n            float dnorm_norm_mean = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n            dnorm_mean = dnorm_mean / C;\n            dnorm_norm_mean = dnorm_norm_mean / C;\n\n            // now iterate again and accumulate all the gradients\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                // gradient contribution to bias\n                dbias[i] += dout_bt[i];\n                // gradient contribution to weight\n                dweight[i] 
+= norm_bti * dout_bt[i];\n                // gradient contribution to input\n                float dval = 0.0f;\n                dval += dnorm_i; // term 1\n                dval -= dnorm_mean; // term 2\n                dval -= norm_bti * dnorm_norm_mean; // term 3\n                dval *= rstd_bt; // final scale\n                dinp_bt[i] += dval;\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// GPU helper functions for atomicAdd on smaller than 32-bit types\n#ifdef ENABLE_BF16\n__device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {\n    uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);\n    __nv_bfloat162* ptr_bf16 = reinterpret_cast<__nv_bfloat162*>(ptr_val & ~uintptr_t(0x3));\n\n    // Prepare the value to add, setting the other half to zero\n    __nv_bfloat162 add_val = (ptr_val & 0x3) ? __halves2bfloat162(__ushort_as_bfloat16(0), val)\n                                             : __halves2bfloat162(val, __ushort_as_bfloat16(0));\n    atomicAdd(ptr_bf16, add_val);\n}\n#endif\n#ifdef ENABLE_FP16\n__device__ void atomicAddX(half* addr, half val) {\n    uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);\n    half2* ptr_fp16 = reinterpret_cast<half2*>(ptr_val & ~uintptr_t(0x3));\n\n    // Prepare the value to add, setting the other half to zero\n    half2 add_val = (ptr_val & 0x3) ? 
__halves2half2(__ushort_as_half(0), val)\n                                    : __halves2half2(val, __ushort_as_half(0));\n    atomicAdd(ptr_fp16, add_val);\n}\n#endif\n__device__ void atomicAddX(float* addr, float val) {\n    atomicAdd(addr, val);\n}\n\n// super naive kernel that just parallelizes over B,T and loops over C\n__global__ void layernorm_backward_kernel1(float* dinp, float* dweight, float* dbias,\n                        const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,\n                        int B, int T, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= B*T) return;\n    int b = idx / T;\n    int t = idx % T;\n\n    const float* dout_bt = dout + b * T * C + t * C;\n    const float* inp_bt = inp + b * T * C + t * C;\n    float* dinp_bt = dinp + b * T * C + t * C;\n    const float mean_bt = mean[b * T + t];\n    const float rstd_bt = rstd[b * T + t];\n\n    // first: two reduce operations\n    float dnorm_mean = 0.0f;\n    float dnorm_norm_mean = 0.0f;\n    for (int i = 0; i < C; i++) {\n        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n        float dnorm_i = weight[i] * dout_bt[i];\n        dnorm_mean += dnorm_i;\n        dnorm_norm_mean += dnorm_i * norm_bti;\n    }\n    dnorm_mean = dnorm_mean / C;\n    dnorm_norm_mean = dnorm_norm_mean / C;\n\n    // now iterate again and accumulate all the gradients\n    for (int i = 0; i < C; i++) {\n        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n        float dnorm_i = weight[i] * dout_bt[i];\n        // gradient contribution to bias\n        atomicAdd(&dbias[i], dout_bt[i]);\n        // gradient contribution to weight\n        atomicAdd(&dweight[i], norm_bti * dout_bt[i]);\n        // gradient contribution to input\n        float dval = 0.0f;\n        dval += dnorm_i; // term 1\n        dval -= dnorm_mean; // term 2\n        dval -= norm_bti * dnorm_norm_mean; // term 3\n        dval *= rstd_bt; // final scale\n   
     dinp_bt[i] += dval;\n    }\n}\n\n// uses shared memory instead for the reduces\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\n__global__ void layernorm_backward_kernel2(Tdinp* dinp, Tparams* dweight, Tparams* dbias,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, float* dweight_tmp, float* dbias_tmp) {\n    extern __shared__ float shared[]; // size = 2 * C\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    int N = B * T;\n    if(idx >= N) { return; } // thread guards\n\n    int b = idx / T;\n    int t = idx % T;\n\n    const Tdout* dout_bt = dout + b * T * C + t * C;\n    const Trest* inp_bt = inp + b * T * C + t * C;\n    Tdinp* dinp_bt = dinp + b * T * C + t * C;\n    const float mean_bt = (float)mean[b * T + t];\n    const float rstd_bt = (float)rstd[b * T + t];\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    __syncthreads();\n\n    // first: two reduce operations\n    float dnorm_mean = 0.0f;\n    float dnorm_norm_mean = 0.0f;\n    for (int i = warp.thread_rank(); i < C; i  += warp.size()) {\n        float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n        float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n        dnorm_mean += dnorm_i;\n        dnorm_norm_mean += dnorm_i * norm_bti;\n    }\n    dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});\n    dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, 
cg::plus<float>{});\n    dnorm_mean = dnorm_mean / C;\n    dnorm_norm_mean = dnorm_norm_mean / C;\n\n    // now iterate again and accumulate all the gradients\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n        float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n        // gradient contribution to bias\n        atomicAdd(&dbias_shared[i], (float)dout_bt[i]);\n        // gradient contribution to weight\n        atomicAdd(&dweight_shared[i], norm_bti * (float)dout_bt[i]);\n        // gradient contribution to input\n        float dval = 0.0f;\n        dval += dnorm_i; // term 1\n        dval -= dnorm_mean; // term 2\n        dval -= norm_bti * dnorm_norm_mean; // term 3\n        dval *= rstd_bt; // final scale\n        dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);\n    }\n    __syncthreads();\n\n    // write to global memory\n    for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n        atomicAdd(&dbias_tmp[i], dbias_shared[i]);\n        atomicAdd(&dweight_tmp[i], dweight_shared[i]);\n    }\n}\n\ntemplate <typename Tparams>\n__global__ void copy_to_dweight_dbias(int C, Tparams* dbias, Tparams* dweight, float* dbias_tmp, float* dweight_tmp) {\n    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < C; i += blockDim.x * gridDim.x) {\n        dbias[i] = (Tparams)dbias_tmp[i];\n        dweight[i] = (Tparams)dweight_tmp[i];\n    }\n}\n\n// kernel2 is 1 threadblock for all Cs on 32 BTs (assuming threadblock size of 1024 threads = 32 warps)\n// To minimise the amount of atomicAdds, we will aim for 1 threadblock per SM, processing (total BTs / threadblocks) BTs\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\n__global__ void layernorm_backward_kernel3(Tdinp* dinp, Tparams* dweight, Tparams* dbias,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, 
int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll 4\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    __syncthreads();\n\n    int warps_in_grid = gridDim.x * warp.meta_group_size();\n    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {\n        int b = idx / T;\n        int t = idx % T;\n\n        const Tdout* dout_bt = dout + b * T * C + t * C;\n        const Trest* inp_bt = inp + b * T * C + t * C;\n        Tdinp* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {\n            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n            dnorm_mean += dnorm_i;\n            dnorm_norm_mean += dnorm_i * norm_bti;\n        }\n        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});\n        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});\n        dnorm_mean = dnorm_mean / C;\n        dnorm_norm_mean = dnorm_norm_mean / C;\n\n        // now iterate again and accumulate all the gradients\n        for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n            float dout_i = (float)__ldcs(&dout_bt[i]);\n            
float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * dout_i;\n            // gradient contribution to bias\n            atomicAdd(&dbias_shared[i], dout_i);\n            // gradient contribution to weight\n            atomicAdd(&dweight_shared[i], norm_bti * dout_i);\n            // gradient contribution to input\n            float dval = 0.0f;\n            dval += dnorm_i; // term 1\n            dval -= dnorm_mean; // term 2\n            dval -= norm_bti * dnorm_norm_mean; // term 3\n            dval *= rstd_bt; // final scale\n            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);\n        }\n    }\n    __syncthreads();\n\n    for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n        atomicAddX(&dbias[i], (Tparams)dbias_shared[i]);\n        atomicAddX(&dweight[i], (Tparams)dweight_shared[i]);\n    }\n}\n\n// atomicCAS version of kernel3\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\n__global__ void layernorm_backward_kernel4(Tdinp* dinp, Tparams* dweight, Tparams* dbias,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll 4\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    __syncthreads();\n\n    int warps_in_grid = gridDim.x * warp.meta_group_size();\n    for (int idx = 
base_idx; idx < B * T; idx += warps_in_grid) {\n        int b = idx / T;\n        int t = idx % T;\n\n        const Tdout* dout_bt = dout + b * T * C + t * C;\n        const Trest* inp_bt = inp + b * T * C + t * C;\n        Tdinp* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {\n            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n            dnorm_mean += dnorm_i;\n            dnorm_norm_mean += dnorm_i * norm_bti;\n        }\n        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});\n        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});\n        dnorm_mean = dnorm_mean / C;\n        dnorm_norm_mean = dnorm_norm_mean / C;\n\n        // now iterate again and accumulate all the gradients\n        for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n            float dout_i = (float)__ldcs(&dout_bt[i]);\n            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * dout_i;\n            // gradient contribution to bias\n            atomicAdd(&dbias_shared[i], dout_i);\n            // gradient contribution to weight\n            atomicAdd(&dweight_shared[i], norm_bti * dout_i);\n            // gradient contribution to input\n            float dval = 0.0f;\n            dval += dnorm_i; // term 1\n            dval -= dnorm_mean; // term 2\n            dval -= norm_bti * dnorm_norm_mean; // term 3\n            dval *= rstd_bt; // final scale\n            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);\n        }\n    }\n    __syncthreads();\n\n    __nv_bfloat162* dbiasVec2 = 
reinterpret_cast<__nv_bfloat162*>(dbias);\n    __nv_bfloat162* dweightVec2 = reinterpret_cast<__nv_bfloat162*>(dweight);\n\n    // write to global memory\n    for(int i = threadIdx.x; i < C/2; i+= blockDim.x) {\n        __nv_bfloat162 add_dbias = __halves2bfloat162((__nv_bfloat16)dbias_shared[i*2], (__nv_bfloat16)dbias_shared[i*2+1]);\n        __nv_bfloat162 add_dweight = __halves2bfloat162((__nv_bfloat16)dweight_shared[i*2], (__nv_bfloat16)dweight_shared[i*2+1]);\n\n        // Get the current value from L2 cache\n        __nv_bfloat162 current_dbias = __ldcg(&dbiasVec2[i]);\n        __nv_bfloat162 current_dweight = __ldcg(&dweightVec2[i]);\n\n        // Add the two values\n        __nv_bfloat162 new_dbias = add_dbias + current_dbias;\n        __nv_bfloat162 new_dweight = add_dweight + current_dweight;\n\n        // Write the result back to L2 cache using 32-bit integer atomic compare and exchange\n        unsigned int current_dbias32b = *reinterpret_cast<unsigned int*>(&current_dbias);\n        unsigned int current_dweight32b = *reinterpret_cast<unsigned int*>(&current_dweight);\n\n        unsigned int new_dbias32b = *reinterpret_cast<unsigned int*>(&new_dbias);\n        unsigned int new_dweight32b = *reinterpret_cast<unsigned int*>(&new_dweight);\n\n        unsigned int old_dbias32b = atomicCAS((unsigned int*)&dbiasVec2[i], current_dbias32b, new_dbias32b);\n        unsigned int old_dweight32b = atomicCAS((unsigned int*)&dweightVec2[i], current_dweight32b, new_dweight32b);\n\n        // If the value has changed between read and atomic, we need to try again\n        while (old_dbias32b != current_dbias32b) {\n            current_dbias32b = old_dbias32b;\n            new_dbias = *reinterpret_cast<__nv_bfloat162*>(&current_dbias32b) + add_dbias;\n            new_dbias32b = *reinterpret_cast<unsigned int*>(&new_dbias);\n            old_dbias32b = atomicCAS((unsigned int*)&dbiasVec2[i], current_dbias32b, new_dbias32b);\n        }\n\n        while (old_dweight32b != 
current_dweight32b) {\n            current_dweight32b = old_dweight32b;\n            new_dweight = *reinterpret_cast<__nv_bfloat162*>(&current_dweight32b) + add_dweight;\n            new_dweight32b = *reinterpret_cast<unsigned int*>(&new_dweight);\n            old_dweight32b = atomicCAS((unsigned int*)&dweightVec2[i], current_dweight32b, new_dweight32b);\n        }\n    }\n}\n\n// FP32 scratchpad per threadgroup, zero atomics except atomicAdd on unsigned int for the flag (based on kernel3)\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\n__global__ void layernorm_backward_kernel5(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C + 1\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll 4\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);\n    __syncthreads();\n\n    int warps_in_grid = gridDim.x * warp.meta_group_size();\n    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {\n        int b = idx / T;\n        int t = idx % T;\n\n        const Tdout* dout_bt = dout + b * T * C + t * C;\n        const Trest* inp_bt = inp + b * T * C + t * C;\n        Tdinp* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float 
rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {\n            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n            dnorm_mean += dnorm_i;\n            dnorm_norm_mean += dnorm_i * norm_bti;\n        }\n        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});\n        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});\n        dnorm_mean = dnorm_mean / C;\n        dnorm_norm_mean = dnorm_norm_mean / C;\n\n        // now iterate again and accumulate all the gradients\n        for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n            float dout_i = (float)__ldcs(&dout_bt[i]);\n            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * dout_i;\n            // gradient contribution to bias\n            atomicAdd(&dbias_shared[i], dout_i);\n            // gradient contribution to weight\n            atomicAdd(&dweight_shared[i], norm_bti * dout_i);\n            // gradient contribution to input\n            float dval = 0.0f;\n            dval += dnorm_i; // term 1\n            dval -= dnorm_mean; // term 2\n            dval -= norm_bti * dnorm_norm_mean; // term 3\n            dval *= rstd_bt; // final scale\n            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);\n        }\n    }\n    __syncthreads();\n\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch + C * gridDim.x;\n    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C * gridDim.x));\n\n    for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n        scratch_dbias[i + C*blockIdx.x] = dbias_shared[i];\n        scratch_dweight[i + C*blockIdx.x] = dweight_shared[i];\n    }\n    __threadfence();\n    
__syncthreads();\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicAdd(scratchFlag, 1);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        // last block to finish, accumulate the scratchpad\n        for (int i = threadIdx.x; i < C; i += blockDim.x) {\n            float dbias_sum = 0.0f;\n            float dweight_sum = 0.0f;\n            #pragma unroll 8\n            for (int j = 0; j < gridDim.x; j++) {\n                dbias_sum += scratch_dbias[i + j*C];\n                dweight_sum += scratch_dweight[i + j*C];\n            }\n            dbias[i] = (Tparams)((float)dbias[i] + dbias_sum);\n            dweight[i] = (Tparams)((float)dweight[i] + dweight_sum);\n        }\n    }\n}\n\n// single FP32 scratchpad shared by all the threadblocks (based on kernels 3 & 5)\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\n__global__ void layernorm_backward_kernel6(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C + 1\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int base_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll 4\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);\n    __syncthreads();\n\n    int warps_in_grid = gridDim.x * warp.meta_group_size();\n    for (int idx = base_idx; idx < B * T; idx += 
warps_in_grid) {\n        int b = idx / T;\n        int t = idx % T;\n\n        const Tdout* dout_bt = dout + b * T * C + t * C;\n        const Trest* inp_bt = inp + b * T * C + t * C;\n        Tdinp* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warp.thread_rank(); i < C; i  += warp.size()) {\n            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n            dnorm_mean += dnorm_i;\n            dnorm_norm_mean += dnorm_i * norm_bti;\n        }\n        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});\n        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});\n        dnorm_mean = dnorm_mean / C;\n        dnorm_norm_mean = dnorm_norm_mean / C;\n\n        // now iterate again and accumulate all the gradients\n        for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n            float dout_i = (float)__ldcs(&dout_bt[i]);\n            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * dout_i;\n            // gradient contribution to bias\n            atomicAdd(&dbias_shared[i], dout_i);\n            // gradient contribution to weight\n            atomicAdd(&dweight_shared[i], norm_bti * dout_i);\n            // gradient contribution to input\n            float dval = 0.0f;\n            dval += dnorm_i; // term 1\n            dval -= dnorm_mean; // term 2\n            dval -= norm_bti * dnorm_norm_mean; // term 3\n            dval *= rstd_bt; // final scale\n            dinp_bt[i] = (Tdinp)((float)dinp_bt[i] + dval);\n        }\n    }\n\n    // Accumulate into a FP32 scratchpad\n    // BF16 atomics are potentially much slower... 
and this is more precise!\n    __syncthreads();\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch + C;\n    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));\n    for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n        atomicAdd(&scratch_dbias[i], dbias_shared[i]);\n        atomicAdd(&scratch_dweight[i], dweight_shared[i]);\n    }\n    __syncthreads();\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicAdd(scratchFlag, 1);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n            // todo - potentially do stochastic rounding here as well\n            dbias[i] = (Tparams)scratch_dbias[i];\n            dweight[i] = (Tparams)scratch_dweight[i];\n        }\n    }\n}\n\n\n// Same as kernel 6 but without cooperative groups or templates\n__global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                        const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,\n                        int B, int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C + 1\n    int warpId = threadIdx.x / warpSize; // warp index within a block\n    int warpsInBlock = blockDim.x / warpSize;\n    int base_idx = blockIdx.x * warpsInBlock + warpId;\n    int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp\n    int warps_in_grid = gridDim.x * warpsInBlock;\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll 4\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);\n    __syncthreads();\n\n    for (int idx = base_idx; idx < B * T; idx += warps_in_grid) {\n      
  int b = idx / T;\n        int t = idx % T;\n\n        const floatX* dout_bt = dout + b * T * C + t * C;\n        const floatX* inp_bt = inp + b * T * C + t * C;\n        floatX* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warpThreadIdx; i < C; i  += warpSize) {\n            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * (float)dout_bt[i];\n            dnorm_mean += dnorm_i;\n            dnorm_norm_mean += dnorm_i * norm_bti;\n        }\n        dnorm_mean = warpReduceSum(dnorm_mean);\n        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean);\n\n        dnorm_mean = dnorm_mean / C;\n        dnorm_norm_mean = dnorm_norm_mean / C;\n\n        // now iterate again and accumulate all the gradients\n        for (int i = warpThreadIdx; i < C; i += warpSize) {\n            float dout_i = (float)__ldcs(&dout_bt[i]);\n            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;\n            float dnorm_i = (float)weight[i] * dout_i;\n            // gradient contribution to bias\n            atomicAdd(&dbias_shared[i], dout_i);\n            // gradient contribution to weight\n            atomicAdd(&dweight_shared[i], norm_bti * dout_i);\n            // gradient contribution to input\n            float dval = 0.0f;\n            dval += dnorm_i; // term 1\n            dval -= dnorm_mean; // term 2\n            dval -= norm_bti * dnorm_norm_mean; // term 3\n            dval *= rstd_bt; // final scale\n            dinp_bt[i] = (floatX)((float)dinp_bt[i] + dval);\n        }\n    }\n\n    // Accumulate into a FP32 scratchpad\n    // BF16 atomics are potentially much slower... 
and this is more precise!\n    __syncthreads();\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch + C;\n    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));\n    for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n        atomicAdd(&scratch_dbias[i], dbias_shared[i]);\n        atomicAdd(&scratch_dweight[i], dweight_shared[i]);\n    }\n    __syncthreads();\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicAdd(scratchFlag, 1);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n            // todo - potentially do stochastic rounding here as well\n            dbias[i] = (floatX)scratch_dbias[i];\n            dweight[i] = (floatX)scratch_dweight[i];\n        }\n    }\n}\n\n__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)\n                layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                                            const floatX* dout, const floatX* inp, const floatX* weight,\n                                            const floatX* mean, const floatX* rstd,\n                                            int B, int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C + 1\n    int warpId = threadIdx.x / warpSize; // warp index within a block\n    int warpsInBlock = blockDim.x / warpSize; //number of warps in block\n    int baseIdx = blockIdx.x * warpsInBlock + warpId;\n    int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp\n    int warpsInGrid = gridDim.x * warpsInBlock;\n    int C_per_iteration = warpSize * x128::size;\n    int iterations_C = C / C_per_iteration;\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    for(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       
dweight_shared[i] = 0.0f;\n    }\n    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);\n    __syncthreads();\n\n    for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {\n        int b = idx / T;\n        int t = idx % T;\n\n        const floatX* dout_bt = dout + b * T * C + t * C;\n        const floatX* inp_bt = inp + b * T * C + t * C;\n        floatX* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warpThreadIdx * x128::size; i < C; i += warpSize * x128::size) {\n            x128 dout128_i   = load128(dout_bt + i);\n            x128 inp128_i    = load128(inp_bt  + i);\n            x128 weight128_i = load128(weight  + i);\n            for (int k = 0; k < x128::size; k++) {\n                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;\n                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n        }\n        dnorm_mean = warpReduceSum(dnorm_mean) / C;\n        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;\n\n        // now iterate again and accumulate all the gradients\n        // unfortunately we cannot use the same index for x128 arrays and shared memory\n        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)\n        // so this would result in an 8-way bank conflict, and kill performance\n        // so instead, we use a shared memory friendly index, and reorder before the final write\n        for (int i = 0; i < iterations_C; i++) {\n            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);\n            int shared_index = warpThreadIdx + (i * C_per_iteration);\n            x128 dout128   = 
load128cs(dout_bt + global_index);\n            x128 inp128    = load128cs(inp_bt  + global_index);\n            x128 dinp128   = load128(dinp_bt   + global_index);\n            x128 weight128 = load128(weight    + global_index);\n\n            for (int x = 0; x < x128::size; x++) {\n                float dout_i = (float)dout128[x];\n                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;\n                float dnorm_i = (float)weight128[x] * dout_i;\n                // gradient contribution to bias (using shared memory friendly index)\n                atomicAdd(&dbias_shared[shared_index + x*warpSize], dout_i);\n                // gradient contribution to weight (using shared memory friendly index)\n                atomicAdd(&dweight_shared[shared_index + x*warpSize], norm_bti * dout_i);\n                // gradient contribution to input\n                float dval = 0.0f;\n                dval += dnorm_i; // term 1\n                dval -= dnorm_mean; // term 2\n                dval -= norm_bti * dnorm_norm_mean; // term 3\n                dval *= rstd_bt; // final scale\n                dinp128[x] = (floatX)((float)dinp128[x] + dval);\n            }\n            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing\n            store128cg(dinp_bt + global_index, dinp128);\n        }\n    }\n    // Accumulate into a FP32 scratchpad\n    // BF16 atomics are potentially much slower... 
and this is more precise!\n    // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though\n    __syncthreads();\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch + C;\n    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));\n    for(int i = threadIdx.x; i < C; i+= blockDim.x) {\n        // global atomics in the same \"shared memory banking friendly\" order\n        atomicAdd(&scratch_dbias[i], dbias_shared[i]);\n        atomicAdd(&scratch_dweight[i], dweight_shared[i]);\n    }\n    __syncthreads();\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicInc(scratchFlag, gridDim.x);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        for (int i = warpId; i < iterations_C; i += warpsInBlock) {\n            // reorder from atomic/shared memory-friendly index to real global memory index\n            // and convert from float/FP32 to floatX/BF16 for the final write\n            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);\n            int shared_index = warpThreadIdx + (i * C_per_iteration);\n\n            x128 dbias128 = load128(dbias + global_index);\n            x128 dweight128 = load128(dweight + global_index);\n            for (int x = 0; x < x128::size; x++) {\n                float s_db = scratch_dbias[shared_index + x*warpSize];\n                float s_dw = scratch_dweight[shared_index + x*warpSize];\n                dbias128[x] = (floatX)(s_db + (float)dbias128[x]);\n                dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);\n            }\n            store128(dbias + global_index, dbias128);\n            store128(dweight + global_index, dweight128);\n        }\n    }\n}\n\n__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                                            const floatX* dout, const floatX* inp, const floatX* weight,\n                                            
const floatX* mean, const floatX* rstd,\n                                            int B, int T, int C) {\n    if(C % (32 * x128::size) != 0) {\n        if(threadIdx.x == 0 && blockIdx.x == 0) {\n            printf(\"Number of channels is not a multiple of 32 * x128::size\");\n        }\n        __trap();       // prefer to crash here than run into a deadlock later on\n    }\n    int BLOCK_SIZE = blockDim.x;\n    int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block\n    extern __shared__ float shared[]; // size = 2 * C + 2 * BLOCK_SIZE + 1\n\n    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block\n    int baseIdx = blockIdx.x * warpsInBlock + warpId;\n    int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp\n    int warpsInGrid = gridDim.x * warpsInBlock;\n    int C_per_iteration = WARP_SIZE * x128::size;\n    int iterations_C = ceil_div(C, C_per_iteration) + 2;\n\n    // the first half of shared memory is bias, second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n    float* dbias_tmp_shared = shared + 2 * C;\n    float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE;\n\n    // init shared memory to zero\n    for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE);\n    __syncthreads();\n\n    for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {\n        int b = idx / T;\n        int t = idx % T;\n\n        const floatX* dout_bt = dout + b * T * C + t * C;\n        const floatX* inp_bt = inp + b * T * C + t * C;\n        floatX* dinp_bt = dinp + b * T * C + t * C;\n        const float mean_bt = (float)mean[b * T + t];\n        const float rstd_bt = (float)rstd[b * T + t];\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warpThreadIdx * x128::size; i 
< C; i += WARP_SIZE * x128::size) {\n            x128 dout128_i   = load128(dout_bt + i);\n            x128 inp128_i    = load128(inp_bt  + i);\n            x128 weight128_i = load128(weight  + i);\n            for (int k = 0; k < x128::size; k++) {\n                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;\n                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n        }\n        dnorm_mean = warpReduceSum(dnorm_mean) / C;\n        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;\n\n        // now iterate again and accumulate all the gradients\n        // unfortunately we cannot use the same index for x128 arrays and shared memory\n        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)\n        // so this would result in an 8-way bank conflict, and kill performance\n        // so instead, we use a shared memory friendly index, and reorder before the final write\n        for (int i = 0; i < iterations_C; i++) {\n            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);\n            int shared_index = warpThreadIdx + (i * C_per_iteration);\n            if (global_index >= C) {\n                break;\n            }\n\n            x128 dout128   = load128cs(dout_bt + global_index);\n            x128 inp128    = load128cs(inp_bt  + global_index);\n            x128 dinp128   = load128(dinp_bt   + global_index);\n            x128 weight128 = load128(weight    + global_index);\n\n            for (int x = 0; x < x128::size; x++) {\n                float dout_i = (float)dout128[x];\n                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;\n                float dnorm_i = (float)weight128[x] * dout_i;\n\n                // sum up the gradients for bias and weight across the entire block\n                // this is basically a reduction (but only 
inter-warp, not intra-warp)\n                // doing it this way allows us to avoid using atomics while using many warps\n                if (warpId != 0) {\n                    dbias_tmp_shared[threadIdx.x] = dout_i;\n                    dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i;\n                }\n                __syncthreads();\n                if (warpId == 0) {\n                    float dbias_tmp = dout_i;\n                    float dweight_tmp = norm_bti * dout_i;\n                    for (int j = 1; j < warpsInBlock; j++) {\n                        dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE];\n                        dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE];\n                    }\n                    // gradient contribution to bias (using shared memory friendly index)\n                    dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp;\n                    // gradient contribution to weight (using shared memory friendly index)\n                    dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp;\n                }\n                __syncthreads();\n\n                // gradient contribution to input\n                float dval = 0.0f;\n                dval += dnorm_i; // term 1\n                dval -= dnorm_mean; // term 2\n                dval -= norm_bti * dnorm_norm_mean; // term 3\n                dval *= rstd_bt; // final scale\n                dinp128[x] = (floatX)((float)dinp128[x] + dval);\n            }\n            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing\n            store128cg(dinp_bt + global_index, dinp128);\n        }\n    }\n    __syncthreads();\n    // Each block writes its partial sum to global memory\n    // The last block to finish becomes responsible for summing up all the partial sums\n    // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)\n    unsigned int* scratchFlag = 
(unsigned int*)(scratch);\n    // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned\n    scratch += 32;\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch + C;\n    for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) {\n        // Write to global memory in the same \"shared memory banking friendly\" order\n        scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i];\n        scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i];\n    }\n    __syncthreads();\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicInc(scratchFlag, gridDim.x);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        // Reduction of the partial sums by the final block\n        // todo - there isn't enough parallelism even inside that single SM...\n        // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?!\n        for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) {\n            f128 dbias_accum = f128::zeros();\n            f128 dweight_accum = f128::zeros();\n\n            for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {\n                int offset = i + 2*C*read_block_idx;\n                f128 dbias128 = load128(scratch_dbias + offset);\n                f128 dweight128 = load128(scratch_dweight + offset);\n                for(int k = 0; k < f128::size; k++) {\n                    dbias_accum[k] += dbias128[k];\n                    dweight_accum[k] += dweight128[k];\n                }\n            }\n            store128(dbias_shared + i, dbias_accum);\n            store128(dweight_shared + i, dweight_accum);\n        }\n        __syncthreads();\n\n        // reorder from atomic/shared memory-friendly index to real global memory index\n        // and convert from float/FP32 to floatX/BF16 for the final write\n        // this is separate also because it cannot use as many warps as the above (f128 vs x128)\n        // 
todo - if we split this code into another kernel, we could maybe do it at the same time?\n        for (int i = warpId; i < iterations_C; i += warpsInBlock) {\n            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);\n            int shared_index = warpThreadIdx + (i * C_per_iteration);\n            if (global_index >= C) {\n                break;\n            }\n\n            x128 dbias128 = load128(dbias + global_index);\n            x128 dweight128 = load128(dweight + global_index);\n            for (int x = 0; x < x128::size; x++) {\n                float s_db = dbias_shared[shared_index + x*WARP_SIZE];\n                float s_dw = dweight_shared[shared_index + x*WARP_SIZE];\n                dbias128[x] = (floatX)(s_db + (float)dbias128[x]);\n                dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);\n            }\n            store128(dbias + global_index, dbias128);\n            store128(dweight + global_index, dweight128);\n        }\n    }\n}\n\n\n// similar to kernel 9, but uses vectors to access shared memory, which also avoids the bank conflict problems,\n// and makes us require fewer barriers, at the cost of increased shared memory consumption.\n// warning: this kernel is _extremely_ close to getting register spills, so many \"optimizations\" turn out to be unhelpful\n// or need to be implemented in a very specific way.\n__global__ void __launch_bounds__(512, 2)\nlayernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                            const floatX* dout, const floatX* inp, const floatX* weight,\n                            const floatX* mean, const floatX* rstd,\n                            int B, int T, int C) {\n    int BLOCK_SIZE = blockDim.x;\n    int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block\n    extern __shared__ float shared[]; // size = 2 * rounded_C + 2 * (BLOCK_SIZE - 32) * f128::size\n\n    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block\n    int baseIdx 
= blockIdx.x * warpsInBlock + warpId;\n    int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp\n    int warpsInGrid = gridDim.x * warpsInBlock;\n    int C_per_iteration = WARP_SIZE * x128::size;\n    int iterations_C = ceil_div(C, C_per_iteration); // + 2;\n\n    // the first half of shared memory is bias, second is weight\n    size_t rounded_C = ceil_div(C, (32 * x128::size)) * (32 * x128::size);\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + rounded_C;\n    // warp zero doesn't actually write to the _tmp_shared memory locations, so we don't need to reserve memory\n    // the obvious solution is to change the addressing below to use (threadIdx.x-32) as offset, but that causes\n    // register spills, so instead we mess with the base pointer here, which doesn't increase register usage.\n    float* dbias_tmp_shared = shared + 2 * rounded_C - WARP_SIZE * f128::size;\n    float* dweight_tmp_shared = shared + 2 * rounded_C + f128::size * BLOCK_SIZE - 2 * WARP_SIZE * f128::size;\n\n    // init shared memory to zero\n    for(int i = threadIdx.x * f128::size; i < rounded_C; i += BLOCK_SIZE * f128::size) {\n        store128(dbias_shared + i, f128::zeros());\n        store128(dweight_shared + i, f128::zeros());\n    }\n    __syncthreads();\n\n    for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) {\n        const floatX* dout_bt = dout + bt * C;\n        const floatX* inp_bt = inp +bt * C;\n        floatX* dinp_bt = dinp + bt * C;\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) {\n            x128 dout128_i   = load128(dout_bt + i);\n            x128 inp128_i    = load128(inp_bt  + i);\n            x128 weight128_i = load128(weight  + i);\n            for (int k = 0; k < x128::size; k++) {\n                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];\n    
            dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * (float)inp128_i[k];\n            }\n        }\n\n        const float mean_bt = (float)mean[bt];\n        const float rstd_bt = (float)rstd[bt];\n        dnorm_mean = warpReduceSum(dnorm_mean) / C;\n        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C * rstd_bt - dnorm_mean * mean_bt * rstd_bt;\n\n        for (int c = 0; c < iterations_C; c++) {\n            int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration);\n\n            x128 dout128   = x128::zeros();\n            x128 inp128    = x128::zeros();\n            x128 dinp128   = x128::zeros();\n            x128 weight128 = x128::zeros();\n\n            if(global_index < C) {\n                dout128 = load128cs(dout_bt + global_index);\n                inp128 = load128cs(inp_bt + global_index);\n                dinp128 = load128(dinp_bt + global_index);\n                weight128 = load128(weight + global_index);\n            }\n\n            for(int o = 0; o < x128::size / f128::size; ++o) {\n                f128 dbias_f;\n                f128 dweight_f;\n                for(int i = 0; i < f128::size; ++i) {\n                    int x = o * f128::size + i;\n                    float dout_i = (float)dout128[x];\n                    float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;\n                    dbias_f[i] = dout_i;\n                    dweight_f[i] = norm_bti * dout_i;\n\n                    float dval = 0.0f;\n                    dval += (float) weight128[x] * (float)dout128[x]; // term 1\n                    dval -= dnorm_mean; // term 2\n                    dval -= norm_bti * dnorm_norm_mean; // term 3\n                    dval *= rstd_bt; // final scale\n                    dinp128[x] = (floatX) ((float) dinp128[x] + dval);\n                }\n\n                if (warpId != 0) {\n                    store128(dbias_tmp_shared + threadIdx.x * f128::size, dbias_f);\n                    // 
this seems to generate a 64-bit store, instead of 128-bit.\n                    // however, forcing 128-bit (e.g., using inline ptx), results in register\n                    // spilling and much worse performance, so we'll keep it like this for now\n                    // but ideally, we could reduce the register pressure a little.\n                    store128(dweight_tmp_shared + threadIdx.x * f128::size, dweight_f);\n                }\n                __syncthreads();\n                if (warpId == 0) {\n                    for (int j = 1; j < warpsInBlock; j++) {\n                        f128 dbias_tmp = load128(dbias_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE));\n                        f128 dweight_tmp = load128(dweight_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE));\n                        for(int i = 0; i < f128::size; ++i) {\n                            dbias_f[i] += dbias_tmp[i];\n                            dweight_f[i] += dweight_tmp[i];\n                        }\n                    }\n                }\n                __syncthreads();\n                if (warpId == 0) {\n                    f128 db_old = load128(dbias_shared + global_index + f128::size * o);\n                    f128 dw_old = load128(dweight_shared + global_index + f128::size * o);\n                    for(int i = 0; i < f128::size; ++i) {\n                        dbias_f[i] += db_old[i];\n                        dweight_f[i] += dw_old[i];\n                    }\n                    store128(dbias_shared + global_index + f128::size * o, dbias_f);\n                    store128(dweight_shared + global_index + f128::size * o, dweight_f);\n                }\n            }\n            if(global_index < C) {\n                // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing\n                store128cg(dinp_bt + global_index, dinp128);\n            }\n        }\n    }\n    __syncthreads();\n    // Each block writes its partial 
sum to global memory\n    // The last block to finish becomes responsible for summing up all the partial sums\n    // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)\n    unsigned int* scratchFlag = (unsigned int*)(scratch);\n    // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned\n    scratch += 32;\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch + C;\n    for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) {\n        // Write to global memory in the same \"shared memory banking friendly\" order\n        store128(scratch_dbias + i + 2*C*blockIdx.x, load128(dbias_shared + i));\n        store128(scratch_dweight + i + 2*C*blockIdx.x, load128(dweight_shared + i));\n    }\n    __syncthreads();\n    // that portion of shared memory is no longer used, so we can repurpose it for the scratch flag.\n    unsigned int *tmp_flag = (unsigned int*)(shared + 2*rounded_C);\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicInc(scratchFlag, gridDim.x);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        // Reduction of the partial sums by the final block\n        // todo - there isn't enough parallelism even inside that single SM...\n        // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?!\n        for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) {\n            f128 dbias_accum = f128::zeros();\n            f128 dweight_accum = f128::zeros();\n\n            for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {\n                int offset = i + 2*C*read_block_idx;\n                f128 dbias128 = load128(scratch_dbias + offset);\n                f128 dweight128 = load128(scratch_dweight + offset);\n                for(int k = 0; k < f128::size; k++) {\n                    dbias_accum[k] += dbias128[k];\n                    dweight_accum[k] 
+= dweight128[k];\n                }\n            }\n            store128(dbias_shared + i, dbias_accum);\n            store128(dweight_shared + i, dweight_accum);\n        }\n        __syncthreads();\n\n        // convert from float/FP32 to floatX/BF16 for the final write\n        // this is separate because it cannot use as many warps as the above (f128 vs x128)\n        // todo - if we split this code into another kernel, we could maybe do it at the same time?\n        for (int c = warpId; c < iterations_C; c += warpsInBlock) {\n            int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration);\n            if (global_index >= C) {\n                break;\n            }\n\n            x128 dbias128 = load128(dbias + global_index);\n            x128 dweight128 = load128(dweight + global_index);\n            for(int o = 0; o < x128::size / f128::size; ++o) {\n                f128 s_db = load128(dbias_shared + global_index + o * f128::size);\n                f128 s_dw = load128(dweight_shared + global_index + o * f128::size);\n                for(int i = 0; i < f128::size; ++i) {\n                    int x = o * f128::size + i;\n                    dbias128[x] = (floatX)(s_db[i] + (float)dbias128[x]);\n                    dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]);\n                }\n            }\n            store128(dbias + global_index, dbias128);\n            store128(dweight + global_index, dweight128);\n        }\n    }\n}\n\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\nvoid layernorm_backward1(float* dinp, float* dweight, float* dbias,\n                        const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,\n                        int B, int T, int C, const int block_size) {\n    const int N = B * T;\n    const int grid_size = ceil_div(N, block_size);\n    layernorm_backward_kernel1<<<grid_size, block_size>>>(dinp, 
dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward2(Tdinp* dinp, Tparams* dweight, Tparams* dbias,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n    const int N = B * T;\n    const int grid_size = ceil_div(32*N, block_size);\n    size_t shared_mem_size = 2 * C * sizeof(float);\n    float* dweight_tmp;\n    float* dbias_tmp;\n    cudaCheck(cudaMalloc(&dweight_tmp, C * sizeof(float)));\n    cudaCheck(cudaMalloc(&dbias_tmp, C * sizeof(float)));\n    cudaMemset(dweight_tmp, 0, C * sizeof(float));\n    cudaMemset(dbias_tmp, 0, C * sizeof(float));\n    layernorm_backward_kernel2<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, dweight_tmp, dbias_tmp);\n    copy_to_dweight_dbias<<<1, 512>>>(C, dweight, dbias, dweight_tmp, dbias_tmp);\n    cudaCheck(cudaFree(dweight_tmp));\n    cudaCheck(cudaFree(dbias_tmp));\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward3(Tdinp* dinp, Tparams* dweight, Tparams* dbias,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n    const int grid_size = (1024/block_size) * cuda_num_SMs;\n    size_t shared_mem_size = 2 * C * sizeof(float);\n    layernorm_backward_kernel3<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward4(Tdinp* dinp, Tparams* dweight, Tparams* dbias,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n 
                       int B, int T, int C, int block_size) {\n        const int grid_size = (1024/block_size) * cuda_num_SMs;\n        size_t shared_mem_size = 2 * C * sizeof(float);\n        layernorm_backward_kernel4<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward5(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n        const int grid_size = 1 * cuda_num_SMs; // only support 1 block per SM for simplicity, 1024 threads is best anyway\n        size_t shared_mem_size = (2 * C + 1) * sizeof(float);\n        cudaMemset(scratch, 0, (grid_size * 2 * C + 1) * sizeof(float));\n        layernorm_backward_kernel5<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward6(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n        const int grid_size = (1024/block_size) * cuda_num_SMs;\n        size_t shared_mem_size = (2 * C + 1) * sizeof(float);\n\n        // Including this as part of the timing until we can parallelise it\n        // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams\n        cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float));\n\n        layernorm_backward_kernel6<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, 
C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward7(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n        const int grid_size = (1024/block_size) * cuda_num_SMs;\n        size_t shared_mem_size = (2 * C + 1) * sizeof(float);\n\n        // Including this as part of the timing until we can parallelise it\n        // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams\n        cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float));\n\n        layernorm_backward_kernel7<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n        const int grid_size = (1024/block_size) * cuda_num_SMs;\n        size_t shared_mem_size = (2 * C + 1) * sizeof(float);\n\n        // Including this as part of the timing until we can parallelise it\n        // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams\n        cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float));\n\n        layernorm_backward_kernel8<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward9(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                        const Tdout* 
dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                        int B, int T, int C, int block_size) {\n\n        assert(C % (32 * x128::size) == 0  && \"Channels must be divisible by (32 * x128::size)\");\n        const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs?\n        size_t shared_mem_size = (2 * C + 2 * block_size + 1) * sizeof(float);\n\n        cudaMemset(scratch, 0, 1 * sizeof(float)); // just need to memset the flag for this version\n        layernorm_backward_kernel9<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);\n}\n\ntemplate <typename Tdinp, typename Tparams, typename Tdout, typename Trest>\nvoid layernorm_backward10(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,\n                         const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,\n                         int B, int T, int C, int block_size) {\n        if(block_size == 1024) {\n            block_size = 512;\n        }\n        //assert(C % (32 * x128::size) == 0  && \"Channels must be divisible by (32 * x128::size)\");\n        const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs?\n        size_t rounded_C = ceil_div(C, (32 * x128::size)) * (32 * x128::size);\n        size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float);\n\n        cudaCheck(cudaMemset(scratch, 0, 1 * sizeof(float))); // just need to memset the flag for this version\n        layernorm_backward_kernel10<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);\n        cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid layernorm_backward(int kernel_num,\n                        floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                
        const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,\n                        int B, int T, int C,\n                        const int block_size) {\n    switch (kernel_num) {\n#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)\n        case 1:\n            layernorm_backward1(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n#endif\n        case 2:\n            layernorm_backward2(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n        case 3:\n            layernorm_backward3(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n#if defined(ENABLE_BF16)\n        case 4:\n            layernorm_backward4(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n#endif\n        case 5:\n            layernorm_backward5(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n        case 6:\n            layernorm_backward6(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n        case 7:\n            layernorm_backward7(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n        case 8:\n            layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n        case 9:\n            layernorm_backward9(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n        case 10:\n            layernorm_backward10(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);\n            break;\n    default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n    cudaCheck(cudaGetLastError());\n}\n\n// 
----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 1600;   // this is the problematic size\n\n    // first do the forward pass in CPU\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n    float* mean = (float*)malloc(B * T * sizeof(float));\n    float* rstd = (float*)malloc(B * T * sizeof(float));\n    float* inp = make_random_float(B * T * C);\n    float* weight = make_random_float(C);\n    float* bias = make_random_float(C);\n    layernorm_forward_cpu(out, mean, rstd, inp, weight, bias, B, T, C);\n\n    // now do the backward pass, again on CPU\n    float *dout = make_random_float(B * T * C);\n    float *dinp = make_zeros_float(B * T * C);\n    float *dweight = make_zeros_float(C);\n    float *dbias = make_zeros_float(C);\n    layernorm_backward_cpu(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);\n\n    // the above calculations act as the reference\n    // now let's do the same on the GPU\n\n    // read kernel_num from command line\n    int kernel_num = 2;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // move all the variables we need for backward pass onto the GPU\n    floatX* d_dinp;\n    floatX* d_dweight;\n    floatX* d_dbias;\n    floatX* d_dout;\n    floatX* d_inp;\n    floatX* d_weight;\n    floatX* d_mean;\n    floatX* d_rstd;\n    float* d_scratch;\n    cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_dweight, C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_dbias, C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_rstd, B * T * 
sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float)));\n    // copy over the \"inputs\" to the backward call\n    cudaCheck(memcpy_convert(d_dout, dout, B * T * C));\n    cudaCheck(memcpy_convert(d_inp, inp, B * T * C));\n    cudaCheck(memcpy_convert(d_weight, weight, C));\n    cudaCheck(memcpy_convert(d_mean, mean, B * T));\n    cudaCheck(memcpy_convert(d_rstd, rstd, B * T));\n\n    // launch the kernel\n    // removed 768 because it doesn't work for kernel9 despite being OK in train_gpt2.cu?!\n    int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024};\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        // init the \"outputs\" of the backward call to zeros\n        cudaCheck(cudaMemset(d_dinp, 0, B * T * C * sizeof(floatX)));\n        cudaCheck(cudaMemset(d_dweight, 0, C * sizeof(floatX)));\n        cudaCheck(cudaMemset(d_dbias, 0, C * sizeof(floatX)));\n\n        layernorm_backward(kernel_num, d_dinp, d_dweight, d_dbias, d_scratch, d_dout, d_inp, d_weight, d_mean, d_rstd,\n                           B, T, C, block_size);\n\n        // check the correctness of the kernel\n        float error_threshold_dinp = sizeof(floatX) == 4 ? 1e-3f : 1e-1f; // allow larger errors for BF16/FP16\n        float error_threshold_dparams = sizeof(floatX) == 4 ? 
1e-3f : 5e-1f; // much, much larger...\n        printf(\"Checking correctness...\\n\");\n        printf(\"dinp:\\n\");\n        validate_result(d_dinp, dinp, \"dinp\", B * T * C, error_threshold_dinp);\n        printf(\"dweight:\\n\");\n        validate_result(d_dweight, dweight, \"dweight\", C, error_threshold_dparams);\n        printf(\"dbias:\\n\");\n        validate_result(d_dbias, dbias, \"dbias\", C, error_threshold_dparams);\n\n        printf(\"All results match for block_size=%d.\\n\\n\", block_size);\n    }\n\n    // now time the kernel\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        int repeat_times = 100;\n        float elapsed_time = benchmark_kernel(repeat_times, layernorm_backward, kernel_num,\n                                              d_dinp, d_dweight, d_dbias, d_scratch, d_dout, d_inp, d_weight, d_mean, d_rstd,\n                                              B, T, C, block_size);\n        printf(\"block_size %4d time %.4f ms\\n\", block_size, elapsed_time);\n    }\n\n    // cleanups\n    free(out);\n    free(mean);\n    free(rstd);\n    free(inp);\n    free(weight);\n    free(bias);\n    free(dout);\n    free(dinp);\n    free(dweight);\n    free(dbias);\n    cudaCheck(cudaFree(d_dinp));\n    cudaCheck(cudaFree(d_dweight));\n    cudaCheck(cudaFree(d_dbias));\n    cudaCheck(cudaFree(d_dout));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_weight));\n    cudaCheck(cudaFree(d_mean));\n    cudaCheck(cudaFree(d_rstd));\n    cudaCheck(cudaFree(d_scratch));\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/layernorm_forward.cu",
    "content": "/*\nKernels for layernorm forward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt layernorm_forward.cu -o layernorm_forward\n\nversion 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C\n./layernorm_forward 1\n\nversion 2 parallelizes over all of B,T,C\n./layernorm_forward 2\n\nversion 3 uses cooperative groups to parallelize over all of B,T,C\n./layernorm_forward 3\n\nversion 4 uses a more clever way to estimate variance, var(x) = mean(x**2) - mean(x)**2\n          (allowing us to do a single pass over x on load)\n./layernorm_forward 4\n\nversion 5 allocates blocks per row instead of warps per row, same alg as 4 otherwise\n./layernorm_forward 5\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n#include <assert.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include \"common.h\"\n// ----------------------------------------------------------------------------\n// CPU code reference\n\n// GPT-2 layernorm forward pass\nvoid layernorm_forward_cpu(float* out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C) {\n    float eps = 1e-5f;\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the input position inp[b,t,:]\n            const float* x = inp + b * T * C + t * C;\n            // calculate the mean\n            float m = 0.0f;\n            for (int i = 0; i < C; i++) {\n                m += x[i];\n            }\n            m = m/C;\n            // calculate the variance (without any bias correction)\n            float v = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v/C;\n            // calculate the rstd\n            float s = 1.0f / sqrtf(v + eps);\n            // seek to the output 
position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i < C; i++) {\n                float n = (s * (x[i] - m)); // normalized output\n                float o = n * weight[i] + bias[i]; // scale and shift it\n                out_bt[i] = o; // write\n            }\n            // cache the mean and rstd for the backward pass later\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// naive drag and drop implementation into kernel, parallelize over B,T, loop over C\n__global__ void layernorm_forward_kernel1(float* out, float* mean, float* rstd,\n                                 const float* inp, const float* weight, const float* bias,\n                                 int N, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    float eps = 1e-5f;\n\n    if (idx < N) {\n        // seek to the input position inp[idx,:]\n        const float* x = inp + idx * C;\n        // calculate the mean\n        float m = 0.0f;\n        for (int i = 0; i < C; i++) {\n            m += x[i];\n        }\n        m = m / C;\n        // calculate the variance (without any bias correction)\n        float v = 0.0f;\n        for (int i = 0; i < C; i++) {\n            float xshift = x[i] - m;\n            v += xshift * xshift;\n        }\n        v = v / C;\n        // calculate the rstd\n        float s = 1.0f / sqrtf(v + eps);\n        // seek to the output position in out[idx,:]\n        float* out_idx = out + idx * C;\n        for (int i = 0; i < C; i++) {\n            float n = (s * (x[i] - m)); // normalized output\n            float o = n * weight[i] + bias[i]; // scale and shift it\n            out_idx[i] = o; // write\n        }\n        // cache the mean and rstd for the backward pass later\n        mean[idx] = m;\n        rstd[idx] = s;\n    }\n}\n\n__global__ void 
mean_kernel(float* mean, const float* inp, int N, int C, int block_size) {\n    extern __shared__ float shared[];\n    int idx = blockIdx.x; // range [0, B*T)\n    int tid = threadIdx.x; // range [0, block_size)\n    const float* x = inp + idx * C;\n    // thread coarsening\n    float sum = 0.0f;\n    for (int i = tid; i < C; i += block_size) {\n        sum += x[i];\n    }\n    shared[tid] = sum;\n    __syncthreads();\n    // reductions\n    for (int stride = block_size / 2; stride >= 1; stride /= 2) {\n        __syncthreads();\n        if (tid < stride) {\n            shared[tid] += shared[tid + stride];\n        }\n    }\n    // write the final result (at thread 0) to global memory\n    if (tid == 0) {\n        mean[idx] = shared[0] / C;\n    }\n}\n\n__global__ void rstd_kernel(float* rstd, const float* inp, const float* mean, int N, int C, int block_size) {\n    extern __shared__ float shared[];\n    int idx = blockIdx.x; // range [0, B*T)\n    int tid = threadIdx.x; // range [0, block_size)\n    const float* x = inp + idx * C;\n    float m = mean[idx];\n    // thread coarsening\n    float sum = 0.0f;\n    for (int i = tid; i < C; i += block_size) {\n        float diff = x[i] - m;\n        sum += diff * diff;\n    }\n    shared[tid] = sum;\n    __syncthreads();\n    // reductions\n    for (int stride = block_size / 2; stride >= 1; stride /= 2) {\n        __syncthreads();\n        if (tid < stride) {\n            shared[tid] += shared[tid + stride];\n        }\n    }\n    // write the final result (at thread 0) to global memory\n    if (tid == 0) {\n        rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f);\n    }\n}\n\n__global__ void normalization_kernel(float* out, const float* inp, float* mean, float* rstd,\n                                     const float* weight, const float* bias, int B, int T, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    int bt = idx / C;\n    int c = idx % C;\n\n    float m = mean[bt];\n    float s = rstd[bt];\n    
float xi = inp[idx];\n    float n = s * (xi - m);\n    float o = n * weight[c] + bias[c];\n\n    out[idx] = o;\n}\n\n__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float*  __restrict__ inp, const float*  __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    // meta_group_size is the number of warps in a block, and meta_group_rank is the warp index\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= N) {\n        return;\n    }\n\n    // the row of input that this group of threads is responsible for\n    const float* x = inp + idx * C;\n\n    // mean\n    float sum = 0.0f;\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        sum += x[i];\n    }\n    sum = cg::reduce(warp, sum, cg::plus<float>{});\n    float m = sum / C;\n    if(warp.thread_rank() == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n\n    // rstd\n    sum = 0.0f;\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        float diff = x[i] - m;\n        sum += diff * diff;\n    }\n    sum = cg::reduce(warp, sum, cg::plus<float>{});\n    float s = rsqrtf(sum / C + 1e-5f);\n    if(warp.thread_rank() == 0 && rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n\n    // final normalization and scaling by weight/bias\n    float* o = out + idx * C;\n    for (int c = warp.thread_rank(); c < C; c += warp.size()) {\n        // load and store using the .cs \"streaming\" hint to the compiler,\n        // indicating that this data will not be reused soon, and can be streamed through the caches\n        // this allows the threads to get more cache-hits for the (shared) weight and bias 
parameters\n        float n = s * (__ldcs(x+c) - m);\n        __stcs(o+c, n * weight[c] + bias[c]);\n    }\n}\n\n// same as kernel 3 but uses var(x) == mean(x**2) - mean(x)**2\n__global__ void layernorm_forward_kernel4(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float*  __restrict__ inp, const float*  __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= N) {\n        return;\n    }\n\n    // the row of input that this group of threads is responsible for\n    const float* x = inp + idx * C;\n\n    // thread coarsening through the row, reduce the sum in series\n    float sum = 0.0; // stores sum(x)\n    float sum2 = 0.0; // stores sum(x**2)\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        float xi = x[i];\n        sum += xi;\n        sum2 += xi * xi;\n    }\n    // warp-level reduction at the end\n    sum = cg::reduce(warp, sum, cg::plus<float>{}); // sum(x)\n    sum2 = cg::reduce(warp, sum2, cg::plus<float>{}); // sum(x**2)\n    sum /= C; // mean(x)\n    sum2 /= C; // mean(x**2)\n\n    // mean, var, rstd\n    float m = sum;\n    float var = sum2 - sum * sum;\n    float s = rsqrtf(var + 1e-5f);\n\n    // store the mean, no need to cache it\n    if(warp.thread_rank() == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n    // store the rstd, no need to cache it\n    if(warp.thread_rank() == 0 && rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n    // final normalization and scaling by weight/bias\n    float* o = out + idx * C;\n    for (int c = warp.thread_rank(); c < C; c += warp.size()) {\n        float n = s * (__ldcs(x+c) - m);\n        
__stcs(o+c, n * weight[c] + bias[c]);\n    }\n}\n\n// like 4, but in kernel 5 we have each block doing one row, not just a single warp\n__global__ void layernorm_forward_kernel5(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float*  __restrict__ inp, const float*  __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    __shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps\n    __shared__ float shared_sum2[32]; // warps will be writing into shared memory after warp-reduce\n    int num_warps = blockDim.x / 32;\n    int warp_id = threadIdx.x / 32;\n    int lane_id = threadIdx.x % 32;\n    int idx = blockIdx.x; // simply one block per row\n    // the row of input that this group of threads is responsible for\n    const float* x = inp + idx * C;\n    // thread coarsening through the row, reduce the sum in series\n    float thread_sum = 0.0; // stores sum(x)\n    float thread_sum2 = 0.0; // stores sum(x**2)\n    // for (int i = C + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {\n    for (int i = threadIdx.x; i < C; i += blockDim.x) {\n        float xi = x[i];\n        thread_sum += xi;\n        thread_sum2 += xi * xi;\n    }\n    // warp-level reduction\n    float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{}); // sum(x)\n    float warp_sum2 = cg::reduce(warp, thread_sum2, cg::plus<float>{}); // sum(x**2)\n    // store the warp-level reduction in shared memory (we could have lane_id == 0 guard but not needed)\n    shared_sum[warp_id] = warp_sum;\n    shared_sum2[warp_id] = warp_sum2;\n    __syncthreads();\n    // load results from shared memory to threads, pad with zeros for threads that are out of bounds\n    warp_sum = (lane_id < num_warps) ? 
shared_sum[lane_id] : 0.0f;\n    warp_sum2 = (lane_id < num_warps) ? shared_sum2[lane_id] : 0.0f;\n    // now reduce the warp-level reductions\n    float block_sum = cg::reduce(warp, warp_sum, cg::plus<float>{}); // sum(x)\n    float block_sum2 = cg::reduce(warp, warp_sum2, cg::plus<float>{}); // sum(x**2)\n    // mean, var, rstd\n    block_sum /= C; // mean(x)\n    block_sum2 /= C; // mean(x**2)\n    float m = block_sum;\n    float var = block_sum2 - m * m;\n    float s = rsqrtf(var + 1e-5f);\n    // store the mean, no need to cache it\n    if(threadIdx.x == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n    // store the rstd, no need to cache it\n    if(threadIdx.x == 0 && rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n    // final normalization and scaling by weight/bias\n    float* o = out + idx * C;\n    for (int i = threadIdx.x; i < C; i += blockDim.x) {\n        float n = s * (__ldcs(x+i) - m);\n        __stcs(o+i, n * weight[i] + bias[i]);\n    }\n}\n\n// Inspired by `fused_residual_forward_kernel5` in fused_residual_forward.cu\n__global__ void layernorm_forward_kernel6(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float*  __restrict__ inp, const float*  __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    assert(blockDim.x == WARP_SIZE);\n\n    // load weights and biases into shared memory\n    // do this before we allow any threads to exit!\n    extern __shared__ char params[];\n    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so\n    // let's keep everything as x128\n    x128* s_weight = reinterpret_cast<x128*>(params);\n    x128* s_bias = reinterpret_cast<x128*>(params) + (C / x128::size);\n    x128* s_in = reinterpret_cast<x128*>(params) + ((2 + threadIdx.y) * C / x128::size);\n\n    int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * 
x128::size;\n    for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {\n        s_weight[i/x128::size] = load128(weight + i);\n        s_bias[i/x128::size] = load128(bias + i);\n    }\n    __syncthreads();\n\n    int idx = blockIdx.x * blockDim.y + threadIdx.y;\n    if(idx >= N) { return; } // guard\n\n    // adjust pointers to current token\n    inp += idx * C;\n    out += idx * C;\n\n    const float eps = 1e-5f;\n    float sum = 0.0f;\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 in_data = load128cs(inp + c);\n        for(int k = 0; k < x128::size; ++k) {\n            sum += (float)in_data[k];\n        }\n        s_in[c / x128::size] = in_data;\n    }\n\n    sum = warpReduceSum(sum);\n    float m = sum / C;\n    float v = 0.f;\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 in_data = s_in[c / x128::size];\n        for(int k = 0; k < x128::size; ++k) {\n            v += ((float)in_data[k] - m) * ((float)in_data[k] - m);\n        }\n    }\n\n    v = warpReduceSum(v) / C;\n    float s = rsqrtf(v + eps);\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 in_data = s_in[c / x128::size];\n        const x128 w = s_weight[c / x128::size];\n        const x128 b = s_bias[c / x128::size];\n        x128 out_data;\n        for(int k = 0; k < x128::size; ++k) {\n            float n = s * ((float)in_data[k] - m); // normalized output\n            float o = n * (float)w[k] + (float)b[k]; // scale and shift it\n            out_data[k] = o;\n        }\n\n        store128cs(out + c, out_data);\n    }\n    // cache the mean and rstd for the backward pass later\n    if(threadIdx.x == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n    // store the rstd, no need to cache it\n    if(threadIdx.x == 0 && rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n}\n\n// 
----------------------------------------------------------------------------\n// kernel launcher\n\nvoid layernorm_forward1(float* out, float* mean, float* rstd,\n                           const float* inp, const float* weight, const float* bias,\n                           int B, int T, int C,\n                           const int block_size) {\n    const int N = B * T;\n    const int grid_size = ceil_div(N, block_size);\n    layernorm_forward_kernel1<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_forward2(float* out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C,\n                       const int block_size) {\n    int N = B * T;\n    // in mean and rstd, threads cooperate within blocks via reductions\n    mean_kernel<<<N, block_size, block_size * sizeof(float)>>>(mean, inp, N, C, block_size);\n    cudaCheck(cudaGetLastError());\n    rstd_kernel<<<N, block_size, block_size * sizeof(float)>>>(rstd, inp, mean, N, C, block_size);\n    cudaCheck(cudaGetLastError());\n    // in the normalization, everything just gets flattened out\n    const int block_size2 = 256;\n    const int grid_size = ceil_div(B * T * C, block_size2);\n    normalization_kernel<<<grid_size, block_size2>>>(out, inp, mean, rstd, weight, bias, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_forward3(float* out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C,\n                       const int block_size) {\n    assert(block_size % 32 == 0);\n    const int N = B * T;\n    const int grid_size = ceil_div(N * 32, block_size);\n    layernorm_forward_kernel3<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_forward4(float* 
out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C,\n                       const int block_size) {\n    assert(block_size % 32 == 0);\n    const int N = B * T;\n    const int grid_size = ceil_div(N * 32, block_size);\n    layernorm_forward_kernel4<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_forward5(float* out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C,\n                       const int block_size) {\n    assert(block_size % 32 == 0);\n    assert(block_size <= 1024);\n    const int N = B * T;\n    const int grid_size = N;\n    layernorm_forward_kernel5<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_forward6(float* out, float* mean, float* rstd,\n                       const float* inp, const float* weight, const float* bias,\n                       int B, int T, int C,\n                       int block_size) {\n    assert(block_size % 32 == 0);\n    const int N = B * T;\n    int block_y = block_size / WARP_SIZE;\n    const int grid_size = ceil_div(N, block_y);\n    size_t smem = (2 + block_y) * C * sizeof(float);\n\n    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute\n    // this may fail, in which case we fall back to the smem free implementation.\n    cudaCheck(cudaGetLastError());\n    auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);\n    cudaGetLastError();\n    if (status == cudaSuccess) {\n        layernorm_forward_kernel6<<<grid_size, dim3(32, block_y), smem>>>(out, mean, rstd, inp, weight, bias, N, C);\n    } else {\n        const int grid_size = N;\n        // fall back to the version 
without shared memory\n        layernorm_forward_kernel5<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);\n    }\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid layernorm_forward(int kernel_num,\n                    float* out, float* mean, float* rstd,\n                    const float* inp, const float* weight, const float* bias,\n                    int B, int T, int C,\n                    const int block_size) {\n    switch (kernel_num) {\n        case 1:\n            layernorm_forward1(out, mean, rstd, inp, weight, bias, B, T, C, block_size);\n            break;\n        case 2:\n            layernorm_forward2(out, mean, rstd, inp, weight, bias, B, T, C, block_size);\n            break;\n        case 3:\n            layernorm_forward3(out, mean, rstd, inp, weight, bias, B, T, C, block_size);\n            break;\n        case 4:\n            layernorm_forward4(out, mean, rstd, inp, weight, bias, B, T, C, block_size);\n            break;\n        case 5:\n            layernorm_forward5(out, mean, rstd, inp, weight, bias, B, T, C, block_size);\n            break;\n        case 6:\n            layernorm_forward6(out, mean, rstd, inp, weight, bias, B, T, C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n    float* mean = (float*)malloc(B * T * sizeof(float));\n    float* rstd = (float*)malloc(B * T * sizeof(float));\n    float* inp = make_random_float(B * T * C);\n    float* weight = make_random_float(C);\n    float* bias = make_random_float(C);\n\n    // move to GPU\n    
float* d_out;\n    float* d_mean;\n    float* d_rstd;\n    float* d_inp;\n    float* d_weight;\n    float* d_bias;\n    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_weight, C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_bias, C * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_weight, weight, C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_bias, bias, C * sizeof(float), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 2;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    layernorm_forward_cpu(out, mean, rstd, inp, weight, bias, B, T, C);\n\n    // check the correctness of the kernel at all block sizes\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n\n        layernorm_forward(kernel_num, d_out, d_mean, d_rstd, d_inp, d_weight, d_bias, B, T, C, block_size);\n\n        validate_result(d_out, out, \"out\", B * T * C, 1e-5f);\n        validate_result(d_mean, mean, \"mean\", B * T, 1e-5f);\n        validate_result(d_rstd, rstd, \"rstd\", B * T, 1e-5f);\n    }\n\n    printf(\"All results match. 
Starting benchmarks.\\n\\n\");\n\n    // time the kernel at different block sizes\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 2000;\n        float elapsed_time = benchmark_kernel(repeat_times, layernorm_forward,\n                                              kernel_num, d_out, d_mean, d_rstd, d_inp, d_weight, d_bias,\n                                              B, T, C, block_size);\n\n        // napkin math: estimate the memory bandwidth achieved\n        // e.g. A100 40GB PCIe is advertised at 1,555GB/s\n        long memory_ops = (2 * B * T * C) * 4; // *4 for float\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s\\n\", block_size, elapsed_time, memory_bandwidth);\n    }\n\n    // free memory\n    free(out);\n    free(mean);\n    free(rstd);\n    free(inp);\n    free(weight);\n    free(bias);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_mean));\n    cudaCheck(cudaFree(d_rstd));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_weight));\n    cudaCheck(cudaFree(d_bias));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/matmul_backward.cu",
    "content": "/*\nKernels for matmul backward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward\n\nOMP_NUM_THREADS=32 ./matmul_backward 1\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <omp.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid matmul_backward_cpu(float* dinp, float* dweight, float* dbias,\n                     float* dout, float* inp, float* weight,\n                     int B, int T, int C, int OC) {\n    // most of the running time is spent here and in matmul_forward\n    // this backward could be done in a single \"round\" of loops\n    // but that doesn't afford an efficient parallelization strategy\n\n    // backward into inp first, parallelize over B,T\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dout_bt = dout + b * T * OC + t * OC;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            for (int o = 0; o < OC; o++) {\n                float* wrow = weight + o*C;\n                float d = dout_bt[o];\n                for (int i = 0; i < C; i++) {\n                    dinp_bt[i] += wrow[i] * d;\n                }\n            }\n        }\n    }\n    // backward into weight/bias, parallelize over output channels OC\n    #pragma omp parallel for\n    for (int o = 0; o < OC; o++) {\n        double sum = 0.0;\n        for (int b = 0; b < B; b++) {\n            for (int t = 0; t < T; t++) {\n                float* dout_bt = dout + b * T * OC + t * OC;\n                float* inp_bt = inp + b * T * C + t * C;\n                float* dwrow = dweight + o*C;\n                float d = dout_bt[o];\n                if (dbias != NULL) { sum += d; }\n                for (int i = 0; i < C; i++) {\n                    dwrow[i] += 
inp_bt[i] * d;\n                }\n            }\n        }\n        if (dbias != NULL){dbias[o] = sum;}\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// naive kernel to backpropagate only the bias, it's just a sum :'(\n__global__ void matmul_backward_bias_kernel_naive(float* dbias, const float* dout, int B, int T, int OC) {\n    int o = blockIdx.x * blockDim.x + threadIdx.x;\n    if (o < OC) {\n        double sum = 0.0;\n        for (int b = 0; b < B; b++) {\n            for (int t = 0; t < T; t++) {\n                sum += dout[b * T * OC + t * OC + o];\n            }\n        }\n        dbias[o] = sum;\n    }\n}\n\n// use shared memory and coarsening + reductions\n__global__ void matmul_backward_bias_kernel_faster(float* dbias, const float* dout, int B, int T, int OC) {\n    extern __shared__ float shared[];\n    int o = blockIdx.x; // range [0, OC)\n    int tid = threadIdx.x; // range [0, block_size)\n    int block_size = blockDim.x;\n    const float* x = dout + o;\n    // thread coarsening\n    double sum = 0.0;\n    for (int i = tid; i < B * T; i += block_size) {\n        sum += x[i * OC];\n    }\n    shared[tid] = (float) sum;\n    __syncthreads();\n    // reductions\n    for (int stride = block_size / 2; stride >= 1; stride /= 2) {\n        __syncthreads();\n        if (tid < stride) {\n            shared[tid] += shared[tid + stride];\n        }\n    }\n    // write the final result (at thread 0) to global memory\n    if (tid == 0) {\n        dbias[o] = shared[0];\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\n// version1: simple cuBLAS calls\nvoid matmul_backward1(float* dinp, float* dweight, float* dbias,\n                      float* dout, float* inp, float* weight, float* ones,\n                      int B, int T, int C, int OC) {\n    float alpha = 1.0f;\n    float beta = 1.0f; // note we must use beta = 1.0 so that we 
do a +=, as we should, because gradients add\n\n    // for reference the API is:\n    // cublasStatus_t cublasSgemm(cublasHandle_t handle,\n    //                        cublasOperation_t transa, cublasOperation_t transb,\n    //                        int m, int n, int k,\n    //                        const float           *alpha,\n    //                        const float           *A, int lda,\n    //                        const float           *B, int ldb,\n    //                        const float           *beta,\n    //                        float           *C, int ldc)\n\n    // recall the forward pass was calculated with alpha = 1.0f, beta = 0.0f as:\n    // cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC);\n\n    // backward to input\n    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &alpha, weight, C, dout, OC, &beta, dinp, C));\n    // backward to weight\n    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &alpha, inp, C, dout, OC, &beta, dweight, C));\n    // backward to bias, if given\n    if (dbias != NULL) {\n\n        // sum over B,T using matrix vector multiplication with cuBLAS\n        // for reference this API is:\n        // cublasStatus_t cublasSgemv(cublasHandle_t handle, cublasOperation_t trans,\n        //                    int m, int n,\n        //                    const float           *alpha,\n        //                    const float           *A, int lda,\n        //                    const float           *x, int incx,\n        //                    const float           *beta,\n        //                    float           *y, int incy)\n        // dout is (B,T,OC), or in 2D terms (B*T, OC)\n        // cublasCheck(cublasSgemv(cublas_handle, CUBLAS_OP_N, B*T, OC, &alpha, dout, B*T, ones, 1, &beta, dbias, 1));\n        // cublasCheck(cublasSgemv(cublas_handle, CUBLAS_OP_T, OC, B*T, &alpha, dout, OC, ones, 1, 
&beta, dbias, 1));\n\n        // ugh the above isn't working...\n        // let's just do naive calculation for now, fix later\n        // const int block_size=128;\n        // const int grid_size=(OC + block_size - 1) / block_size;\n        // matmul_backward_bias_kernel<<<grid_size, block_size>>>(dbias, dout, B, T, OC);\n\n        // bit faster\n        const int block_size=512;\n        dim3 block_dim(block_size);\n        dim3 grid_dim(OC);\n        size_t shared_mem_size = block_size * sizeof(float);\n        matmul_backward_bias_kernel_faster<<<grid_dim, block_dim, shared_mem_size>>>(dbias, dout, B, T, OC);\n    }\n}\n\nvoid matmul_backward(int kernel_num,\n                     float* dinp, float* dweight, float* dbias,\n                     float* dout, float* inp, float* weight, float* ones,\n                     int B, int T, int C, int OC) {\n    switch (kernel_num) {\n        case 1:\n            matmul_backward1(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n    int OC = 768 * 4; // expansion of 4, e.g. in the MLP\n\n    // set up the device\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n    printf(\"Device %d: %s\\n\", deviceIdx, deviceProp.name);\n\n    // setup cuBLAS and its mathmodes, ensure fp32\n    int enable_tf32 = 0; // use fp32 to get accurate results for checking w.r.t. CPU\n    cublasCheck(cublasCreate(&cublas_handle));\n    printf(\"enable_tf32: %d\\n\", enable_tf32);\n    cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;\n    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));\n\n    // create host memory of random numbers\n    float* dinp = make_zeros_float(B * T * C);\n    float* dweight = make_zeros_float(OC * C);\n    float* dbias = make_zeros_float(OC);\n    float* dout = make_random_float(B * T * OC);\n    float* inp = make_random_float(B * T * C);\n    float* weight = make_random_float(OC * C);\n    float* ones = make_ones_float(OC);\n\n    // move to GPU\n    float* d_dinp;\n    float* d_dweight;\n    float* d_dbias;\n    float* d_dout;\n    float* d_inp;\n    float* d_weight;\n    float* d_ones;\n    cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dweight, OC * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_weight, OC * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_ones, OC * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_dinp, dinp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_dweight, dweight, OC * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_dbias, dbias, OC * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_dout, dout, B * T * OC * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_weight, weight, OC * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_ones, ones, OC * sizeof(float), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // calculate the CPU reference\n    matmul_backward_cpu(dinp, dweight, 
dbias, dout, inp, weight, B, T, C, OC);\n\n    // calculate the GPU version\n    matmul_backward(kernel_num, d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones, B, T, C, OC);\n\n    // compare\n    printf(\"Checking correctness...\\n\");\n    printf(\"dinp:\\n\");\n    validate_result(d_dinp, dinp, \"dinp\", B * T * C, 1e-3f);\n    printf(\"dweight:\\n\");\n    validate_result(d_dweight, dweight, \"dweight\", OC * C, 1e-3f);\n    printf(\"dbias:\\n\");\n    validate_result(d_dbias, dbias, \"dbias\", OC, 1e-3f);\n    printf(\"All results match.\\n\\n\");\n\n    // now benchmark the kernel\n    int repeat_times = 100;\n    float elapsed_time = benchmark_kernel(repeat_times, matmul_backward, kernel_num,\n                                          d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones,\n                                          B, T, C, OC);\n    printf(\"time %.4f ms\\n\", elapsed_time);\n\n    // cleanups\n    free(dinp);\n    free(dweight);\n    free(dbias);\n    free(dout);\n    free(inp);\n    free(weight);\n    free(ones);\n    cudaCheck(cudaFree(d_dinp));\n    cudaCheck(cudaFree(d_dweight));\n    cudaCheck(cudaFree(d_dbias));\n    cudaCheck(cudaFree(d_dout));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_weight));\n    cudaCheck(cudaFree(d_ones));\n    cublasCheck(cublasDestroy(cublas_handle));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/matmul_backward_bias.cu",
    "content": "/*\nKernels for matmul backward pass bias only.\n\nCompile example:\nnvcc -O3 -lcublas -lcublasLt -std=c++17 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias\n\n./matmul_backward_bias 1\n./matmul_backward_bias 2\n./matmul_backward_bias 3\n./matmul_backward_bias 4\n./matmul_backward_bias 5\n\nncu:\nsudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <omp.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include <type_traits>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n\n// ----------------------------------------------------------------------------\n// utility functions\n__host__ __device__ bool isPowerOfTwo(int n) {\n    return (n > 0) && ((n & (n - 1)) == 0);\n}\n\n__host__ __device__ int largestPowerOfTwoLessOrEqual(int n) {\n    // Return the largest power of 2 less than or equal to n\n    if (n < 1) {\n        return 0;\n    }\n\n    while ((n & (n - 1)) > 0) {\n        n = n & (n - 1);\n    }\n\n    return n;\n}\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid matmul_backward_bias_cpu(float* dinp, float* dweight, float* dbias,\n                     float* dout, float* inp, float* weight,\n                     int B, int T, int C, int OC) {\n    for (int o = 0; o < OC; o++) {\n        double sum = 0.0;\n        for (int b = 0; b < B; b++) {\n            for (int t = 0; t < T; t++) {\n                float* dout_bt = dout + b * T * OC + t * OC;\n                sum += dout_bt[o];\n            }\n        }\n        dbias[o] = sum;\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\nfloat* dbias_buffer;\n\n__global__ void matmul_backward_bias_kernel1(floatX* dbias, const floatX* dout, int B, int T, int OC) {\n    extern __shared__ 
float shared[];\n    int o = blockIdx.x; // range [0, OC)\n    int tid = threadIdx.x; // range [0, block_size)\n    int block_size = blockDim.x;\n    const floatX* x = dout + o;\n    // thread coarsening\n    float sum = 0.0;\n    for (int i = tid; i < B * T; i += block_size) {\n        sum += (float)x[i * OC];\n    }\n    shared[tid] = sum;\n    __syncthreads();\n    // reductions\n    for (int stride = block_size / 2; stride >= 1; stride /= 2) {\n        __syncthreads();\n        if (tid < stride) {\n            shared[tid] += shared[tid + stride];\n        }\n    }\n    // write the final result (at thread 0) to global memory\n    if (tid == 0) {\n        dbias[o] = (floatX)((float)dbias[o] + shared[0]);\n    }\n}\n\n// cooperative groups solution, one warp per output channel\n__global__ void matmul_backward_bias_kernel2(floatX* dbias, const floatX* dout, int B, int T, int OC) {\n    // dout is (B, T, OC), dbias is (OC)\n    // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    // meta_group_size is the number of warps in a block (e.g. 
4), meta_group_rank is the warp index (0,1,2,3)\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= OC) { return; }\n    int BT = B * T; // number of elements to reduce in total, per channel\n    // first, thread coarsening to sum reduce the problem size from B*T to 32\n    float sum = 0.0f;\n    for(int i = warp.thread_rank(); i < BT; i += warp.size()) {\n        sum += (float)dout[i * OC + idx];\n    }\n    // now do a warp-level reduce to get the sum across the 32 threads in this warp\n    sum = cg::reduce(warp, sum, cg::plus<float>{});\n    // write the result to output (global memory)\n    if(warp.thread_rank() == 0) {\n        dbias[idx] = (float)dbias[idx] + sum;\n    }\n}\n\n__global__ void matmul_backward_bias_kernel3(floatX* dbias, const floatX* dout, int B, int T, int OC) {\n    // dout is (B, T, OC), dbias is (OC)\n    // in this version of the kernel the entire block of block_size is dedicated to one output channel\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    __shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps\n    int BT = B * T; // number of elements to reduce in total, per channel\n    int num_warps = blockDim.x / 32;\n    int warp_id = threadIdx.x / 32;\n    int lane_id = threadIdx.x % 32;\n    int idx = blockIdx.x; // simply one block per row\n    // round 1: thread coarsening to reduce the problem size from B*T to block_size\n    float thread_sum = 0.0f;\n    for(int i = threadIdx.x; i < BT; i += blockDim.x) {\n        thread_sum += (float)dout[i * OC + idx];\n    }\n    // now do a warp-level reduce to get the sum across the 32 threads in each warp\n    // reduce the problem size from block_size to block_size/32 i.e. 
`num_warps`\n    float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{});\n    // store the warp sum in shared memory (we could have lane_id == 0 guard but not needed)\n    shared_sum[warp_id] = warp_sum;\n    __syncthreads();\n    // load results from shared memory to threads, pad with zeros for threads that are out of bounds\n    warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f;\n    // now reduce the warp-level reductions\n    float block_sum = cg::reduce(warp, warp_sum, cg::plus<float>{}); // sum(x)\n    // write the result to output (global memory)\n    if(threadIdx.x == 0) {\n        dbias[idx] = (float)dbias[idx] + block_sum;\n    }\n}\n\n// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:\n// dbias = dout.sum((0,1))\n// the idea is to employ one block to reduce along several columns,\n// where each block has a width of 32 columns to ensure coalesced access.\n// at the end we accumulate the reductions performed by the warps in each block via shared memory\n__global__ void matmul_backward_bias_kernel4(floatX* dbias, const floatX* dout, int B, int T, int OC) {\n    // this kernel is launched with 1D grid_dim of OC/32\n    // for example let's say block_size is 128\n    extern __shared__ float smem[]; // of size block_size (128)\n    const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3\n    const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31\n    const int tl = blockIdx.x * warpSize; // pointer to the start column for this block\n    const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4\n\n    // pointer to the start of the column for one lane of threads\n    // so e.g. 
4 (`vstep`) threads (of the same lane_id) will reduce this one column\n    const floatX* dout_col = dout + tl + lane_id;\n\n    // column reductions by looping through the rows\n    // each of the 4 threads offsets by its warp_id and then skips by vstep\n    // together these 4 threads cover all B*T rows of this (lane_id) column\n    // importantly, consecutive threads (in threadId) are processing adjacent columns,\n    // leading to a coalesced memory access pattern\n    float dout_sum = 0.0f;\n    for (int row = warp_id; row < B * T; row += vstep) {\n        dout_sum += (float)dout_col[row * OC];\n    }\n    smem[lane_id + warp_id * warpSize] = dout_sum;\n    __syncthreads();\n\n    // warp_id 0 reduces the shared memory column-wise, linearly\n    dout_sum = 0.0f;\n    if (warp_id == 0) {\n        for (int j = 0; j < vstep; j++) {\n            dout_sum += smem[lane_id + j * warpSize];\n        }\n        dbias[tl + lane_id] = (float)dbias[tl + lane_id] + dout_sum;\n    }\n}\n\n#ifndef ENABLE_BF16\n__global__ void matmul_backward_bias_kernel5(floatX* dbias, const floatX* dout, int B, int T, int OC) {\n    int oc = blockIdx.x * blockDim.x + threadIdx.x;\n    if(oc >= OC) return;\n    float sum = 0.0;\n    // grid-wide loop for maximum parallelism\n    for (int i = blockIdx.y; i < B * T; i += gridDim.y) {\n        sum += (float)dout[i * OC + oc];\n    }\n    // and atomically add everything together. 
atomics within one block are conflict-free!\n    atomicAdd(dbias + oc, sum);\n}\n#endif\n\n\n__global__ void cast_and_add_kernel(floatX* dst, const float* src, size_t n) {\n    // used only for matmul_backward_bias kernel, a little bit embarrassing TODO delete later\n    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) { dst[idx] = (floatX)((float)dst[idx] + src[idx]); } // have to += because dbias is a parameter\n}\n\n__global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, int B, int T, int OC, const int block_size) {\n    // note: this kernel reads in floatX, but it writes to float!\n    // this is because we're using atomics, which are super slow in < fp32 precision on < H100 GPUs\n    // so the trick is to do fp32 atomics to a buffer, and then copy_and_cast the result to floatX\n    // (this also results in higher accuracy than doing accumulation directly in floatX)\n\n    // see comments in matmul_backward_bias7() for an explanation of block/grid dimensions etc.\n    const int block_size_x = 32;\n    const int block_size_y = block_size / block_size_x; // 16\n    const int OC_per_warp = block_size_x * x128::size;  // 256 at BF16\n\n    int local_oc = threadIdx.x * x128::size;\n    int global_oc = blockIdx.x * OC_per_warp + local_oc;\n    float accumulators[x128::size];\n    extern __shared__ float shared[];\n\n    for (int k = 0; k < x128::size; k++) {\n        accumulators[k] = 0.0f;\n    }\n    int thread_id = threadIdx.y * block_size_x + threadIdx.x;\n    for (int idx = thread_id; idx < OC_per_warp; idx += block_size) {\n        shared[idx] = 0.0f;\n    }\n    __syncthreads();\n    if(global_oc < OC) {\n        for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {\n            x128 packed_dout = load128(dout + global_oc + idx*OC);\n            for (int k = 0; k < x128::size; k++) {\n                accumulators[k] += (float)packed_dout[k];\n            }\n        }\n        
// we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance,\n        // so we accumulate in a conflict-free order, then reorder to match the global memory order\n        for (int k = 0; k < x128::size; k++) {\n            atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);\n        }\n    }\n    if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data\n    __syncthreads();\n    // read the accumulated values in the conflict-free order\n    int i = threadIdx.x + (threadIdx.y * block_size_x);\n    float tmp = shared[i];\n    __syncthreads();\n    // write them back to shared memory in the global memory order\n    // 8-way bank conflict for BF16 x128, but only 8x per threadblock (rather than 8x per warp)\n    shared[local_oc + threadIdx.y] = tmp;\n    __syncthreads();\n    // now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)\n    if (i + blockIdx.x*OC_per_warp < OC) {\n        atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);\n    }\n}\n\n// We want to decrease the amount of channels handled by each block, so that we need fewer across-block reductions.\n// We do this by realizing the following: For scalar memory access, we need to read one element per thread in a warp\n// to read an entire cacheline, but for vectorized memory access, with 128 bit of data per thread, we only need eight\n// threads to fetch a cacheline, which means that we can already operate on a \"depth\" of four within a single warp.\n// => blockDim.x == 4, blockDim.y == 32/4 = 8\n//\ntemplate<typename OutFloat, bool Atomic>\n__global__ void matmul_backward_bias_kernel8(OutFloat* dbias, const floatX* dout, int B, int T, int OC,\n                                             std::bool_constant<Atomic>) {\n    constexpr const int bdx = 4;\n    constexpr const int bdy = 32 / bdx;\n    assert(blockDim.x == bdx);\n    assert(blockDim.y == bdy);\n\n    int warp_d = 
(int)threadIdx.x;\n    int warp_c = (int)threadIdx.y;\n    int block_d = (int)threadIdx.z;\n\n    const int OC_per_warp = bdy * x128::size;  // 64 at BF16\n\n    int local_oc = warp_c * x128::size;\n    int global_oc = blockIdx.x * OC_per_warp + local_oc;\n\n    int local_bt = warp_d + bdx * block_d;\n    int bt_per_block = bdx * blockDim.z;\n\n    float accumulators[x128::size];\n    for (int k = 0; k < x128::size; k++) {\n        accumulators[k] = 0.0f;\n    }\n\n    if(global_oc < OC) {\n        // sum up over all bt within registers\n        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {\n            x128 packed_dout = load128(dout + global_oc + idx*OC);\n            for (int k = 0; k < x128::size; k++) {\n                accumulators[k] += (float)packed_dout[k];\n            }\n        }\n    }\n\n    __shared__ float sub_results[x128::size][32][bdy];\n\n    // reduce within-warp results\n    for (int k = 0; k < x128::size; k++) {\n        float v = accumulators[k];\n        v += __shfl_down_sync(0xffffffff, v, 1, 4);\n        v += __shfl_down_sync(0xffffffff, v, 2, 4);\n        if(warp_d == 0) {\n            sub_results[k][block_d][warp_c] = v;\n        }\n    }\n    __syncthreads();\n\n    // block-wide reductions\n    for (int k = block_d; k < x128::size; k += blockDim.z) {\n        float a = 0.f;\n        for (int r = warp_d; r < blockDim.z; r += bdx) {\n            float v = sub_results[k][r][warp_c];\n            v += __shfl_down_sync(0xffffffff, v, 1, 4);\n            v += __shfl_down_sync(0xffffffff, v, 2, 4);\n            a += v;\n        }\n        if(warp_d == 0 && global_oc < OC) {\n            // coalesced, but not cacheline-sized\n            if constexpr (!Atomic) {\n                dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);\n            } else {\n                atomicAdd(dbias + global_oc + k, a);\n            }\n        }\n    }\n}\n\n// Like kernel 8, but instead 
of accumulating to the auxiliary buffer, it writes\n// multiple values that need to be summed up in a separate kernel call.\n// If UseAuxBuffer is false, gridDim.y has to be one, and results are added directly\n// to dbias.\ntemplate<typename OutFloat, bool UseAuxBuffer>\n__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int B, int T, int OC,\n                                             std::bool_constant<UseAuxBuffer>) {\n    constexpr const int bdx = 4;\n    constexpr const int bdy = 32 / bdx;\n    assert(blockDim.x == bdx);\n    assert(blockDim.y == bdy);\n\n    int warp_d = (int)threadIdx.x;\n    int warp_c = (int)threadIdx.y;\n    int block_d = (int)threadIdx.z;\n\n    const int OC_per_warp = bdy * x128::size;  // 64 at BF16\n\n    int local_oc = warp_c * x128::size;\n    int global_oc = blockIdx.x * OC_per_warp + local_oc;\n\n    int local_bt = warp_d + bdx * block_d;\n    int bt_per_block = bdx * blockDim.z;\n\n    float accumulators[x128::size];\n    for (int k = 0; k < x128::size; k++) {\n        accumulators[k] = 0.0f;\n    }\n\n    if(global_oc < OC) {\n        // sum up over all bt within registers\n        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {\n            x128 packed_dout = load128(dout + global_oc + idx*OC);\n            for (int k = 0; k < x128::size; k++) {\n                accumulators[k] += (float)packed_dout[k];\n            }\n        }\n    }\n\n    __shared__ float sub_results[x128::size][32][bdy];\n\n    // reduce within-warp results\n    for (int k = 0; k < x128::size; k++) {\n        float v = accumulators[k];\n        v += __shfl_down_sync(0xffffffff, v, 1, 4);\n        v += __shfl_down_sync(0xffffffff, v, 2, 4);\n        if(warp_d == 0) {\n            sub_results[k][block_d][warp_c] = v;\n        }\n    }\n    __syncthreads();\n\n    // block-wide reductions\n    for (int k = block_d; k < x128::size; k += blockDim.z) {\n        float a = 
0.f;\n        for (int r = warp_d; r < blockDim.z; r += bdx) {\n            float v = sub_results[k][r][warp_c];\n            v += __shfl_down_sync(0xffffffff, v, 1, 4);\n            v += __shfl_down_sync(0xffffffff, v, 2, 4);\n            a += v;\n        }\n        if(warp_d == 0 && global_oc < OC) {\n            // coalesced, but not cacheline-sized\n            if constexpr (!UseAuxBuffer) {\n                dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);\n            } else {\n                dbias[global_oc + k + blockIdx.y * OC] = a;\n            }\n        }\n    }\n}\n\n\n__global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, size_t m) {\n    const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size;\n    assert(n % x128::size == 0);\n    if (idx < n) {\n        f128 acc;\n        for(int k = 0; k < f128::size; ++k) {\n            acc[k] = 0.f;\n        }\n\n        for(int l = 0; l < m; ++l) {\n            f128 s = load128(src + idx + n * l);\n            for(int k = 0; k < f128::size; ++k) {\n                acc[k] += s[k];\n            }\n        }\n        for(int k = 0; k < f128::size; ++k) {\n            dst[idx + k] = (floatX) ((float)dst[idx + k] + acc[k]);\n        }\n    }\n}\n\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\n// version1: shared-memory tree reduction, one block per output channel\nvoid matmul_backward_bias1(floatX* dbias, const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    block_size = largestPowerOfTwoLessOrEqual(block_size);\n    assert(isPowerOfTwo(block_size)); // block_size needs to be power of 2 due to the reduction\n    dim3 block_dim(block_size);\n    dim3 grid_dim(OC);\n    size_t shared_mem_size = block_size * sizeof(float);\n    matmul_backward_bias_kernel1<<<grid_dim, block_dim, shared_mem_size>>>(dbias, dout, B, T, OC);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid matmul_backward_bias2(floatX* dbias, 
const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    // block_size 512 seems best\n    const int grid_size = ceil_div(OC * 32, block_size);\n    matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid matmul_backward_bias3(floatX* dbias, const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    // block_size 256 seems best\n    matmul_backward_bias_kernel3<<<OC, block_size>>>(dbias, dout, B, T, OC);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid matmul_backward_bias4(floatX* dbias, const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    assert(OC % 32 == 0); // OC must be divisible by 32 for this kernel\n    const int grid_size = OC / 32;\n    matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);\n    cudaCheck(cudaGetLastError());\n}\n\n#ifndef ENABLE_BF16\nvoid matmul_backward_bias5(floatX* dbias, const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    const int grid_size_x = ceil_div(OC, block_size);\n    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / block_size);\n    matmul_backward_bias_kernel5<<<dim3(grid_size_x, grid_size_y), dim3(block_size)>>>(dbias, dout, B, T, OC);\n    cudaCheck(cudaGetLastError());\n}\n#endif\n\nvoid matmul_backward_bias7(floatX* dbias, const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    if(block_size < 256) {\n        block_size = 256;\n    }\n    // Each warp is responsible for 32 * \"x128::size\" = 256 OCs at BF16 (OC must be a multiple of 256!)\n    // Block size is 512 threads (16 warps) and we reduce those 16 values into 1 at the end\n    // blockDim.x is 32 --> single warp being responsible for those 256 OCs\n    // blockDim.y is 16 --> 16 parallel independent warps processing the same OCs for different BTs\n    // 
gridDim.x is OC / 256 --> each block processes 256 OCs\n    // grimDim.y is max(1, (cuda_num_SMs * threads_per_SM) / (512 * gridDim.x)); --> fill up the entire GPU!\n    const int warp_size = 32;\n    const int OC_per_warp = warp_size * x128::size; // 256 at BF16\n    const int block_size_x = 32;\n    const int block_size_y = block_size / block_size_x; // 16\n    const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16\n    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU!\n\n    assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops\n\n    cudaCheck(cudaMemset(dbias_buffer, 0, OC * sizeof(float)));\n    matmul_backward_bias_kernel7<<<dim3(grid_size_x, grid_size_y),\n        dim3(block_size_x, block_size_y), OC_per_warp * sizeof(float)>>>(dbias_buffer, dout, B, T, OC, block_size);\n    cudaCheck(cudaGetLastError());\n    cast_and_add_kernel<<<ceil_div(OC, 256), 256, 0>>>(dbias, dbias_buffer, OC);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid matmul_backward_bias8(floatX* dbias, const floatX* dout,\n                      int B, int T, int OC, int block_size) {\n    dim3 block_dim = {4, 8, (unsigned)block_size/32};\n    const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16\n    const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 
12 horizontal blocks for 768 OCs at BF16\n    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU!\n\n    // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation\n    // and write results directly to the output.\n    if(grid_size_y == 1) {\n        matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias, dout, B, T, OC, std::bool_constant<false>{});\n        cudaCheck(cudaGetLastError());\n    } else {\n        cudaCheck(cudaMemset(dbias_buffer, 0, OC * sizeof(float)));\n        matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});\n        cudaCheck(cudaGetLastError());\n        cast_and_add_kernel<<<ceil_div(OC, 256), 256, 0>>>(dbias, dbias_buffer, OC);\n        cudaCheck(cudaGetLastError());\n    }\n}\n\n\nvoid matmul_backward_bias9(floatX* dbias, const floatX* dout,\n                           int B, int T, int OC, int block_size) {\n    dim3 block_dim = {4, 8, (unsigned)block_size/32};\n    const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16\n    const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 
12 horizontal blocks for 768 OCs at BF16\n    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU!\n\n    // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation\n    // and write results directly to the output.\n    if(grid_size_y == 1) {\n        matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias, dout, B, T, OC, std::bool_constant<false>{});\n        cudaCheck(cudaGetLastError());\n    } else {\n        // kernel 9 overwrites temp buffer, so no need to memset\n        matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});\n        cudaCheck(cudaGetLastError());\n        reduce_add_sum_kernel<<<ceil_div(OC, 256 * f128::size), 256, 0>>>(dbias, dbias_buffer, OC, grid_size_y);\n        cudaCheck(cudaGetLastError());\n    }\n}\n\nvoid matmul_backward_bias(int kernel_num, floatX* dbias, floatX* dout,\n                     int B, int T, int OC, int block_size) {\n    switch (kernel_num) {\n        case 1:\n            matmul_backward_bias1(dbias, dout, B, T, OC, block_size);\n            break;\n        case 2:\n            matmul_backward_bias2(dbias, dout, B, T, OC, block_size);\n            break;\n        case 3:\n            matmul_backward_bias3(dbias, dout,  B, T, OC, block_size);\n            break;\n        case 4:\n            matmul_backward_bias4(dbias, dout, B, T, OC, block_size);\n            break;\n        case 5:\n#ifndef ENABLE_BF16\n            matmul_backward_bias5(dbias, dout, B, T, OC, block_size);\n#else\n            fprintf(stderr, \"Kernel 5 is only supported for fp32\");\n            exit(1);\n#endif\n            break;\n        case 7:\n            matmul_backward_bias7(dbias, dout, B, T, OC, block_size);\n            break;\n        case 8:\n            matmul_backward_bias8(dbias, dout, B, T, OC, block_size);\n            
break;\n        case 9:\n            matmul_backward_bias9(dbias, dout, B, T, OC, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n    int OC = 768 * 4; // expansion of 4, e.g. in the MLP\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // create host memory of random numbers\n    float* dbias = make_zeros_float(OC);\n    float* dout = make_random_float(B * T * OC);\n\n    // move to GPU\n    floatX* d_dbias;\n    floatX* d_dout;\n    cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&dbias_buffer, OC * sizeof(float) * 32));\n    cudaCheck(memcpy_convert(d_dbias, dbias, OC));\n    cudaCheck(memcpy_convert(d_dout, dout, B * T * OC));\n\n    // ncu debugging / profiling, do a single call\n    // int block_size_debug;\n    // if (kernel_num == 1) { block_size_debug = 512;\n    // } else if (kernel_num == 2) { block_size_debug = 512;\n    // } else { block_size_debug = 256; }\n    // printf(\"kernel %d, block_size %d\\n\", kernel_num, block_size_debug);\n    // matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size_debug);\n    // exit(EXIT_SUCCESS);\n\n    int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};\n\n    // calculate the CPU reference\n    matmul_backward_bias_cpu(NULL, NULL, dbias, dout, NULL, NULL, B, T, C, OC);\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        // memset the bias to zero\n        cudaCheck(cudaMemset(d_dbias, 0, 
OC * sizeof(floatX)));\n        // calculate the GPU version\n        matmul_backward_bias(kernel_num, d_dbias, d_dout, B, T, OC, block_size);\n        // compare\n        printf(\"Checking correctness...\\n\");\n        float tol = std::is_same_v<floatX, float> ? 5e-3f : 1.0f;\n        validate_result(d_dbias, dbias, \"dbias\", OC, tol);\n        printf(\"All results match for block_size=%d.\\n\\n\", block_size);\n    }\n\n    // now benchmark the kernel\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        int repeat_times = 2000;\n        float elapsed_time = benchmark_kernel(repeat_times, matmul_backward_bias, kernel_num,\n                                            d_dbias, d_dout, B, T, OC, block_size);\n        printf(\"block_size %d time %.4f ms\\n\", block_size, elapsed_time);\n    }\n\n    // cleanups\n    free(dbias);\n    free(dout);\n    cudaCheck(cudaFree(dbias_buffer));\n    cudaCheck(cudaFree(d_dbias));\n    cudaCheck(cudaFree(d_dout));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/matmul_forward.cu",
    "content": "/*\nKernels for matmul forward pass.\nIt's advised to use OpenMP here because the CPU implementation is fairly slow otherwise\n\nCompile example:\nnvcc -O3 --use_fast_math -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward -lcublas -lcublasLt\n\nversion 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C\nOMP_NUM_THREADS=32 ./matmul_forward 1\n\nversion 2 calls cuBLAS, very fast\nOMP_NUM_THREADS=32 ./matmul_forward 2\n\nversion 3 calls cuBLASLt, should be even faster\nOMP_NUM_THREADS=32 ./matmul_forward 3\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <cublasLt.h>\n#include <omp.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid matmul_forward_cpu(float* out,\n                    const float* inp, const float* weight, const float* bias,\n                    int B, int T, int C, int OC) {\n    // OC is short for \"output channels\"\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // out will be (B,T,OC)\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* out_bt = out + b * T * OC + t * OC;\n            const float* inp_bt = inp + b * T * C + t * C;\n            for (int o = 0; o < OC; o++) {\n                float val = (bias != NULL) ? 
bias[o] : 0.0f;\n                const float* wrow = weight + o*C;\n                for (int i = 0; i < C; i++) {\n                    val += inp_bt[i] * wrow[i];\n                }\n                out_bt[o] = val;\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// kernel 1: naive kernel, every thread handles one output element, direct global memory access\n__global__ void matmul_forward_kernel1(float* out,\n                                       const float* inp, const float* weight, const float* bias,\n                                       int BT, int C, int OC) {\n    // out is (B,T,OC). OC is short for \"output channels\", e.g. OC = 4 * C\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // in the naive kernel, every thread handles one element of out\n    int bt = blockIdx.x * blockDim.x + threadIdx.x;\n    int oc = blockIdx.y * blockDim.y + threadIdx.y;\n    if (bt < BT && oc < OC) {\n        float val = (bias != NULL) ? 
bias[oc] : 0.0f;\n        const float* wrow = weight + oc * C;\n        const float* inp_bt = inp + bt * C;\n        for (int i = 0; i < C; i++) {\n            val += inp_bt[i] * wrow[i];\n        }\n        out[bt * OC + oc] = val;\n    }\n}\n\n// is there no better way other than just adding bias with a whole separate kernel?\n// this is a highly memory-bound operation, should be fused into the matmul kernel\n// but i can't seem to find a cuBLAS function that does this\n__global__ void add_bias(float* out, const float* bias, int B, int T, int OC) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int stride = blockDim.x * gridDim.x;\n    for (int i = idx; i < B * T * OC; i += stride) {\n        int col = i % OC;\n        out[i] += bias[col];\n    }\n}\n\n// kernel 4: semi-efficient handwritten kernel\n// see trimat_forward.cu for some intermediate development steps\n__device__ float4 ld_vec(const float* address) {\n    return *reinterpret_cast<const float4*>(address);\n}\n\n__device__ void st_vec(float* address, float4 val) {\n    *reinterpret_cast<float4*>(address) = val;\n}\n\n__global__ void __launch_bounds__(16*16) matmul_forward_kernel4(float* out,\n                                       const float* inp, const float* weight, const float* bias,\n                                       int C, int OC) {\n    // out is (B,T,OC). OC is short for \"output channels\", e.g. 
OC = 4 * C\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // each thread handles 8x8 elements; each block 128 by 128 elements.\n    int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y);\n\n    // buffers to cache chunks of the input matrices\n    __shared__ float lhs_s[128][32];\n    __shared__ float rhs_s[128][32];\n\n    // adjust our pointers for the current block\n    inp += 128 * blockIdx.x * C;\n    weight += 128 * blockIdx.y * C;\n    out += 128 * blockIdx.x * OC + 128 * blockIdx.y;\n\n    float vals[8][8] = {};\n    if(bias != NULL) {\n        for (int i = 0; i < 8; i++) {\n            for (int j = 0; j < 8; j += 4) {\n                float4 b = ld_vec(bias + oc + j);\n                vals[i][j+0] = b.x;\n                vals[i][j+1] = b.y;\n                vals[i][j+2] = b.z;\n                vals[i][j+3] = b.w;\n            }\n        }\n    }\n\n    int si_start = 4*(16 * threadIdx.y + threadIdx.x);\n    for (int so = 0; so < C; so += 32) {\n        __syncthreads();\n        int xmod8 = threadIdx.x % 8;\n        int xby8 = threadIdx.x / 8;\n        int xo = 4 * xmod8;\n        for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) {\n            st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo));\n            st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo));\n        }\n        __syncthreads();\n\n        for (int si = si_start; si < si_start + 32; si += 4) {\n            float4 rhs[8];\n            for (int u = 0; u < 8; ++u) {\n                rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]);\n            }\n\n            for (int ii = 0; ii < 8; ++ii) {\n                float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]);\n                for (int ji = 0; ji < 8; ++ji) {\n                    vals[ii][ji] += lhs.x * rhs[ji].x;\n                    vals[ii][ji] += lhs.y * rhs[ji].y;\n                    vals[ii][ji] += lhs.z * rhs[ji].z;\n                    vals[ii][ji] += lhs.w * rhs[ji].w;\n                }\n          
  }\n        }\n    }\n\n    for (int i = 0; i < 8; ++i) {\n        for (int j = 0; j < 8; j += 4) {\n            float4 result;\n            result.x = vals[i][j + 0];\n            result.y = vals[i][j + 1];\n            result.z = vals[i][j + 2];\n            result.w = vals[i][j + 3];\n            st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result);\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\n// kernel 1 is the most naive matmul kernel\nvoid matmul_forward1(float* out,\n                     const float* inp, const float* weight, const float* bias,\n                     int B, int T, int C, int OC,\n                     const int sqrt_block_size) {\n    // out is (B,T,OC). OC is short for \"output channels\", e.g. OC = 4 * C\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    dim3 gridDim(ceil_div(B * T, sqrt_block_size), ceil_div(OC, sqrt_block_size));\n    dim3 blockDim(sqrt_block_size, sqrt_block_size);\n    matmul_forward_kernel1<<<gridDim, blockDim>>>(out, inp, weight, bias, B*T, C, OC);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel 2 calls cuBLAS, which should be very efficient\nvoid matmul_forward2(float* out,\n                     const float* inp, const float* weight, const float* bias,\n                     int B, int T, int C, int OC,\n                     const int sqrt_block_size) {\n    // for reference API is:\n    // cublasStatus_t cublasSgemm(cublasHandle_t handle,\n    //                        cublasOperation_t transa, cublasOperation_t transb,\n    //                        int m, int n, int k,\n    //                        const float           *alpha,\n    //                        const float           *A, int lda,\n    //                        const float           *B, int ldb,\n    //                        const float           *beta,\n    //                        float           *C, int ldc)\n    // for us, inp is (B*T, C), 
weight is (OC, C), out is (B*T, OC)\n    // cuBLAS does C = alpha * A * B + beta * C\n    // where A is mxk, B is kxn, C is mxn\n    // now, because we use row-major storage, cuBLAS (which is column-major) sees our matrices transposed.\n    // algorithmically / in e.g. PyTorch we want to do: out = inp @ weight.T\n    // but because cuBLAS is column-major, we actually want to get it to calculate out.T . Mathematically, this is:\n    // out.T = weight @ inp.T\n    // but again, our variables look transposed, so using the actual weight/inp we have here in this function, this becomes\n    // out.T = weight.T @ inp\n    // so we need to get cuBLAS to calculate weight.T @ inp (the variables here are the actual ones in this function)\n    // => need to call cuBLAS with A = weight, B = inp\n    // => need to call cuBLAS with transa = CUBLAS_OP_T, transb = CUBLAS_OP_N\n\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC));\n    // and now we still have to add the bias... 
(ew)\n    if (bias != NULL) {\n        int block_size = sqrt_block_size * sqrt_block_size;\n        int grid_size = ceil_div(OC * B * T, block_size);\n        add_bias<<<grid_size, block_size>>>(out, bias, B, T, OC);\n        cudaCheck(cudaGetLastError());\n    }\n}\n\n// uses cublasLt to fuse the bias and gelu\n// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul\n// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu\nvoid matmul_forward3(float* out,\n                     const float* inp, const float* weight, const float* bias,\n                     int B, int T, int C, int OC) {\n    int has_bias = (bias != NULL);\n    int has_gelu = 0;\n\n    // check bias alignment\n    if(((uintptr_t)bias % 16) != 0) {\n        printf(\"Bias pointer is not aligned (cuBLASLt requirement)!\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    int returnedResults = 0;\n    cublasLtMatmulDesc_t operationDesc;\n    cublasLtMatmulPreference_t preference;\n    cublasLtMatrixLayout_t weightLayout;\n    cublasLtMatrixLayout_t inputLayout;\n    cublasLtMatrixLayout_t outputLayout;\n    cublasLtMatrixLayout_t biasLayout;\n    cublasLtMatmulHeuristicResult_t heuristic;\n\n    // create the operation descriptor\n    cublasOperation_t opNoTranspose = CUBLAS_OP_N;\n    cublasOperation_t opTranspose = CUBLAS_OP_T;\n    cublasLtEpilogue_t epilogueBias = CUBLASLT_EPILOGUE_DEFAULT;\n    if (has_bias && has_gelu) {\n        epilogueBias = CUBLASLT_EPILOGUE_GELU_BIAS;\n    } else if (has_bias) {\n        epilogueBias = CUBLASLT_EPILOGUE_BIAS;\n    } else if (has_gelu) {\n        epilogueBias = CUBLASLT_EPILOGUE_GELU;\n    }\n    cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute_type, CUDA_R_32F));\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose)));\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, 
sizeof(opNoTranspose)));\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias, sizeof(epilogueBias)));\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));\n\n    // define matrix layouts\n    cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C));\n    cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C));\n    cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, OC, B*T, OC));\n    cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUDA_R_32F, OC, 1, OC));\n\n    // create a preference handle with specified max workspace\n    cublasCheck(cublasLtMatmulPreferenceCreate(&preference));\n    cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference,\n        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n        &cublaslt_workspace_size, sizeof(cublaslt_workspace_size)));\n\n    // find a suitable algorithm\n    cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc,\n        weightLayout, inputLayout, outputLayout, outputLayout,\n        preference, 1, &heuristic, &returnedResults));\n    if (returnedResults == 0) {\n        printf(\"No cuBLASLt algorithm: B: %d, T: %d, C: %d, OC: %d, bias: %d, gelu: %d\\n\",\n            B, T, C, OC, has_bias, has_gelu);\n        exit(EXIT_FAILURE);\n    }\n\n    // call the matmul\n    const float alpha = 1.0f, beta = 0.0f;\n    cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc,\n        &alpha, weight, weightLayout, inp, inputLayout, &beta,\n        out, outputLayout, out, outputLayout, &heuristic.algo,\n        cublaslt_workspace, cublaslt_workspace_size, 0));\n\n    // cleanups\n    cublasCheck(cublasLtMatmulPreferenceDestroy(preference));\n    cublasCheck(cublasLtMatmulDescDestroy(operationDesc));\n    cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout));\n    cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout));\n    
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout));\n    cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));\n}\n\n// handwritten, relatively efficient non-tensorcore matmul kernel\nvoid matmul_forward4(float* out,\n                     const float* inp, const float* weight, const float* bias,\n                     int B, int T, int C, int OC,\n                     int sqrt_block_size) {\n    // out is (B,T,OC). OC is short for \"output channels\", e.g. OC = 4 * C\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    sqrt_block_size = 16;\n\n    dim3 gridDim(ceil_div(B * T, 8*sqrt_block_size), ceil_div(OC, 8*sqrt_block_size));\n    dim3 blockDim(sqrt_block_size, sqrt_block_size);\n    matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, C, OC);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid matmul_forward(int kernel_num,\n                    float* out,\n                    const float* inp, const float* weight, const float* bias,\n                    int B, int T, int C, int OC,\n                    const int sqrt_block_size) {\n    switch (kernel_num) {\n        case 1:\n            matmul_forward1(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);\n            break;\n        case 2:\n            matmul_forward2(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);\n            break;\n        case 3:\n            matmul_forward3(out, inp, weight, bias, B, T, C, OC);\n            break;\n        case 4:\n            matmul_forward4(out, inp, weight, bias, B, T, C, OC, sqrt_block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 32;\n    int T = 1024;\n    int C = 768;\n    int OC = 768 * 4; // expansion of 4, e.g. 
in the MLP\n\n    // set up the device\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n    printf(\"Device %d: %s\\n\", deviceIdx, deviceProp.name);\n\n    // setup cuBLAS and cuBLASLt\n    cublasCheck(cublasCreate(&cublas_handle));\n    cublasCheck(cublasLtCreate(&cublaslt_handle));\n    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')\n    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;\n    printf(\"enable_tf32: %d\\n\", enable_tf32);\n    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;\n    cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;\n    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));\n    // setup the (global) cuBLASLt workspace\n    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * OC * sizeof(float));\n    float* inp = make_random_float(B * T * C);\n    float* weight = make_random_float(OC * C);\n    float* bias = make_random_float(OC);\n\n    // move to GPU\n    float* d_out;\n    float* d_inp;\n    float* d_weight;\n    float* d_bias;\n    cudaCheck(cudaMalloc(&d_out, B * T * OC * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_weight, C * OC * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_bias, OC * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_weight, weight, C * OC * sizeof(float), cudaMemcpyHostToDevice));\n    cudaCheck(cudaMemcpy(d_bias, bias, OC * sizeof(float), cudaMemcpyHostToDevice));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", 
kernel_num);\n\n    // first check the correctness of the kernel\n    matmul_forward_cpu(out, inp, weight, bias, B, T, C, OC);\n\n    // time the kernel at different block sizes\n    int sqrt_block_sizes[] = {4, 8, 16, 32};\n\n    for (int j = 0; j < sizeof(sqrt_block_sizes) / sizeof(int); j++) {\n        int sqrt_block_size = sqrt_block_sizes[j];\n        printf(\"Checking block size %d x %d.\\n\", sqrt_block_size, sqrt_block_size);\n        matmul_forward(kernel_num, d_out, d_inp, d_weight, d_bias, B, T, C, OC, sqrt_block_size);\n        validate_result(d_out, out, \"out\", B * T * OC, 1e-1f);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(sqrt_block_sizes) / sizeof(int); j++) {\n        int sqrt_block_size = sqrt_block_sizes[j];\n\n        int repeat_times = 100;\n        float elapsed_time = benchmark_kernel(repeat_times, matmul_forward,\n                                              kernel_num, d_out, d_inp, d_weight, d_bias,\n                                              B, T, C, OC, sqrt_block_size);\n\n        // napkin math: estimate the flops achieved\n        // e.g. A100 40GB PCIe is advertised at 19.5 TFLOPS fp32\n        float tflops = (float)B * T * C * OC * 2 / elapsed_time * 1e3f / 1e12f;\n        printf(\"sqrt_block_size %4d | time %.4f ms | tflops %.2f\\n\", sqrt_block_size, elapsed_time, tflops);\n    }\n\n    // free memory\n    free(out);\n    free(inp);\n    free(weight);\n    free(bias);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_weight));\n    cudaCheck(cudaFree(d_bias));\n    cudaCheck(cudaFree(cublaslt_workspace));\n    cublasCheck(cublasDestroy(cublas_handle));\n    cublasCheck(cublasLtDestroy(cublaslt_handle));\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/nccl_all_reduce.cu",
    "content": "/*\n\nA simple test of NCCL capabilities.\nFills a vector with 1s on the first GPU, 2s on the second, etc.\nThen aggregates the values in the resulting vectors.\n\nCompile example:\nnvcc -lmpi -lnccl -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lcublas -lcublasLt nccl_all_reduce.cu -o nccl_all_reduce\n\nRun on 2 local GPUs (set -np to a different value to change GPU count):\nmpirun -np 2 ./nccl_all_reduce\n\n*/\n\n#include \"common.h\"\n#include <assert.h>\n#include <cuda_runtime.h>\n#include <mpi.h>\n#include <nccl.h>\n#include <stdint.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <unistd.h>\n\nvoid nccl_check(ncclResult_t status, const char *file, int line) {\n  if (status != ncclSuccess) {\n    printf(\"[NCCL ERROR] at file %s:%d:\\n%s\\n\", file, line,\n           ncclGetErrorString(status));\n    exit(EXIT_FAILURE);\n  }\n}\n#define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__))\n\nvoid mpi_check(int status, const char *file, int line) {\n  if (status != MPI_SUCCESS) {\n    char mpi_error[4096];\n    int mpi_error_len = 0;\n    assert(MPI_Error_string(status, &mpi_error[0], &mpi_error_len) ==\n           MPI_SUCCESS);\n    printf(\"[MPI ERROR] at file %s:%d:\\n%.*s\\n\", file, line, mpi_error_len,\n           mpi_error);\n    exit(EXIT_FAILURE);\n  }\n}\n#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__))\n\n// Sets a vector to a predefined value\n__global__ void set_vector(float *data, int N, float value) {\n  int i = blockIdx.x * blockDim.x + threadIdx.x;\n\n  // Check for out-of-bounds access\n  if (i < N) {\n    data[i] = value;\n  }\n}\n\nsize_t cdiv(size_t a, size_t b) { return (a + b - 1) / b; }\n\n// Parameters specific to training on multiple GPUs.\ntypedef struct {\n  int process_rank;      // Rank of this process among all MPI processes on all hosts. 0 if no multi-GPU.\n  int num_processes;     // Total number of processes on all hosts. 
1 if no multi-GPU.\n  int local_device_idx;  // This process GPU index on current machine. 0 if no multi-GPU.\n  ncclComm_t nccl_comm;  // NCCL communication primitive, used for collective mutli-GPU work.\n} MultiGpuConfig;\n\n// Determine which GPU this process should use.\n// Processes on the same machines use different GPU indicies. Processes on other machines don't.\n// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread\nint multi_gpu_get_local_device_idx(int process_rank, int num_processes) {\n  char hostname[1024];\n  hostname[1023] = '\\0';\n  // All processes on the same machine will share the same hostname.\n  gethostname(hostname, 1023);\n  for (int i=0; i < 1024; i++) {\n    if (hostname[i] == '.') {\n        hostname[i] = '\\0';\n        break;\n    }\n  }\n  uint64_t hostname_hash = 5381;\n  for (int c = 0; hostname[c] != '\\0'; c++){ hostname_hash = ((hostname_hash << 5) + hostname_hash) ^ hostname[c]; }\n\n  // Distribute all hostname hashes to all processes.\n  uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t));\n  all_hostsname_hashes[process_rank] = hostname_hash;\n  mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));\n\n  // Identify which GPU we need to use.\n  int local_device_idx = 0;\n  for (int current_process = 0; current_process < num_processes; ++current_process) {\n     if (current_process == process_rank) {\n      // Found my gpu, local_device_idx now has my target GPU index.\n      break;\n     }\n     if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) {\n      // This process ID runs on the same machine, but it's not me, skip this GPU\n      local_device_idx++;\n     }\n  }\n\n  free(all_hostsname_hashes);\n  return local_device_idx;\n}\n\nMultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {\n    
// Initialize MPI.\n    MultiGpuConfig result;\n    mpiCheck(MPI_Init(argc, argv));\n    mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank));\n    mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes));\n    result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes);\n    printf(\"[Process rank %d] Using GPU %d\\n\", result.process_rank, result.local_device_idx);\n    cudaCheck(cudaSetDevice(result.local_device_idx));\n    ncclUniqueId nccl_id;\n    if (result.process_rank == 0) {\n        ncclCheck(ncclGetUniqueId(&nccl_id));\n    }\n    mpiCheck(MPI_Bcast((void *)&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD));\n    ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank));\n    return result;\n}\n\nvoid multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) {\n    ncclCommDestroy(multi_gpu_config->nccl_comm);\n    mpiCheck(MPI_Finalize());\n}\n\nfloat get_mean(float *arr, size_t size, int process_rank) {\n  double sum = 0.0;\n  for (size_t i = 0; i < size; ++i) {\n    sum += arr[i];\n  }\n  return sum / size;\n}\n\nint main(int argc, char **argv) {\n  // Some constants\n  const size_t all_reduce_buffer_size = 32 * 1024 * 1024;\n  const size_t threads_per_block = 1024;\n\n  MultiGpuConfig multi_gpu_config = multi_gpu_config_init(&argc, &argv);\n\n  // Allocating buffers on each of the devices.\n  float *all_reduce_buffer;\n  cudaCheck(\n      cudaMalloc(&all_reduce_buffer, all_reduce_buffer_size * sizeof(float)));\n\n  int n_blocks = cdiv(all_reduce_buffer_size, threads_per_block);\n  // Set the allocated memory to a defined value.\n  set_vector<<<n_blocks, threads_per_block>>>(\n      all_reduce_buffer, all_reduce_buffer_size,\n      (float)(multi_gpu_config.process_rank + 1));\n  cudaCheck(cudaGetLastError());\n\n  float *all_reduce_buffer_host =\n      (float *)malloc(all_reduce_buffer_size * sizeof(float));\n\n  
cudaCheck(cudaMemcpy(all_reduce_buffer_host, all_reduce_buffer,\n                       sizeof(float) * all_reduce_buffer_size,\n                       cudaMemcpyDeviceToHost));\n\n  printf(\"[Process rank %d] average value before all reduce is %.6f\\n\", multi_gpu_config.process_rank,\n         get_mean(all_reduce_buffer_host, all_reduce_buffer_size,\n                  multi_gpu_config.process_rank));\n\n  float *all_reduce_buffer_recv;\n  cudaCheck(cudaMalloc(&all_reduce_buffer_recv,\n                       all_reduce_buffer_size * sizeof(float)));\n\n  ncclCheck(ncclAllReduce(\n      (const void *)all_reduce_buffer, (void *)all_reduce_buffer_recv,\n      all_reduce_buffer_size, ncclFloat, ncclSum, multi_gpu_config.nccl_comm, 0));\n\n\n  cudaCheck(cudaMemcpy(all_reduce_buffer_host, all_reduce_buffer_recv,\n                       sizeof(float) * all_reduce_buffer_size,\n                       cudaMemcpyDeviceToHost));\n\n  float all_reduce_mean_value = get_mean(all_reduce_buffer_host, all_reduce_buffer_size, multi_gpu_config.process_rank);\n\n  printf(\"[Process rank %d] average value after all reduce is %.6f\\n\", multi_gpu_config.process_rank, all_reduce_mean_value);\n\n  float expected_all_reduce_mean_value = 0.0;\n  for (int i = 0; i != multi_gpu_config.num_processes; ++i) {\n    expected_all_reduce_mean_value += i + 1;\n  }\n  if (abs(expected_all_reduce_mean_value - all_reduce_mean_value) > 1e-5) {\n    printf(\"[Process rank %d] ERROR: Unexpected all reduce value: %.8f, expected %.8f\\n\", multi_gpu_config.process_rank, all_reduce_mean_value, expected_all_reduce_mean_value);\n  } else {\n    printf(\"[Process rank %d] Checked against expected mean value. All good!\\n\", multi_gpu_config.process_rank);\n  }\n\n  free(all_reduce_buffer_host);\n  cudaCheck(cudaFree(all_reduce_buffer));\n  cudaCheck(cudaFree(all_reduce_buffer_recv));\n  multi_gpu_config_free(&multi_gpu_config);\n}\n"
  },
  {
    "path": "dev/cuda/permute.cu",
    "content": "/*\nKernels to demonstrate permute operation.\n\nCompile example:\nnvcc -O3 permute.cu -o permute\n\nThe goal is to permute a 4D matrix from its original shape (dim1, dim2, dim3, dim4) to a new shape (dim4, dim3, dim1, dim2).\n\nBefore permutation, we need to understand how to access elements in a flattened (linear) form of the matrix.\n\nGiven:\n\ndim1 = size of the 1st dimension\ndim2 = size of the 2nd dimension\ndim3 = size of the 3rd dimension\ndim4 = size of the 4th dimension\n\nFor any element in a 4D matrix at position (i1, i2, i3, i4), where:\n\ni1 is the index in dimension 1\ni2 is the index in dimension 2\ni3 is the index in dimension 3\ni4 is the index in dimension 4\n\nIf you find it challenging to calculate the indices i1, i2, i3, and i4, observe the pattern in the index calculations.\nInitially, it might take some time to grasp, but with practice, you'll develop a mental model for it.\n\nTo calculate the indices, use the following formulas:\n\ni1 = (idx / (dim2 * dim3 * dim4)) % dim1;\ni2 = (idx / (dim3 * dim4)) % dim2;\ni3 = (idx / dim4) % dim3;\ni4 = idx % dim4;\n\nPattern Explanation:\nTo find the index for any dimension, divide the thread ID (idx) by the product of all subsequent dimensions.\nThen, perform modulo operation with the current dimension.\n\n\n\nThe linear index in a flattened 1D array is calculated as:\nlinear_idx = i1 × ( dim2 × dim3 × dim4 ) + i2 × ( dim3 × dim4 ) + i3 × dim4 + i4\nThis linear index uniquely identifies the position of the element in the 1D array.\n\nTo permute the matrix, we need to rearrange the indices according to the new shape.\nIn this case, we are permuting from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).\n\nThe new dimension post permutation will be as follows:\n\ndim1 becomes the new 3rd dimension.\ndim2 becomes the new 4th dimension.\ndim3 becomes the new 2nd dimension.\ndim4 becomes the new 1st dimension.\n\npermuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 
+ i2;\n\nHere's how this works:\n\ni4 * (dim3 * dim1 * dim2): This accounts for how many complete dim3 × dim1 × dim2 blocks fit before the current i4 block.\ni3 * (dim1 * dim2): This accounts for the offset within the current i4 block, specifying which i3 block we are in.\ni1 * dim2: This accounts for the offset within the current i3 block, specifying which i1 block we are in.\ni2: This gives the offset within the current i1 block.\n\nLastly at the end we store the current value at idx index of the original value to the permuted index in the permuted_matrix.\n\n\n--------------------------------------------------------------------------------------------------------------------------------------------------------\n\nSimilarly we can follow the above approach to permute matrices of any dimensions.\n\n*/\n\n\n#include <cuda_runtime.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <cmath>\n\n#include \"common.h\"\n\n// CPU function to permute a 4D matrix\nvoid permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {\n    int total_threads = dim1 * dim2 * dim3 * dim4;\n\n    for (int idx = 0; idx < total_threads; idx++) {\n        // Calculate the 4D indices from the linear index\n        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;\n        int i2 = (idx / (dim3 * dim4)) % dim2;\n        int i3 = (idx / dim4) % dim3;\n        int i4 = idx % dim4;\n\n        // Compute the new index for the permuted matrix\n        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)\n        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;\n        out_matrix[permuted_idx] = matrix[idx];\n    }\n}\n\n// CUDA kernel to permute a 4D matrix\n__global__ void permute_kernel(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // Ensure index is within bounds\n    if (idx < dim1 * dim2 * dim3 * dim4) {\n     
   // Calculate the 4D indices from the linear index\n        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;\n        int i2 = (idx / (dim3 * dim4)) % dim2;\n        int i3 = (idx / dim4) % dim3;\n        int i4 = idx % dim4;\n\n        // Compute the new index for the permuted matrix\n        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)\n        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;\n        out_matrix[permuted_idx] = matrix[idx];\n    }\n}\n\n\nint main() {\n    int dim_1 = 24;\n    int dim_2 = 42;\n    int dim_3 = 20;\n    int dim_4 = 32;\n\n    // Set up the device\n    int deviceIdx = 0;\n    cudaSetDevice(deviceIdx);\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n    printf(\"Device %d: %s\\n\", deviceIdx, deviceProp.name);\n\n    // Allocate host memory\n    float* matrix = make_random_float(dim_1 * dim_2 * dim_3 * dim_4);\n    float* permuted_matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));\n\n    // Initialize the matrix with random values\n\n    // Allocate device memory\n    float *d_matrix, *d_permuted_matrix;\n    cudaMalloc(&d_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));\n    cudaMalloc(&d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));\n\n    // Copy matrix from host to device\n    cudaMemcpy(d_matrix, matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyHostToDevice);\n\n    // Perform permutation on CPU\n    clock_t start = clock();\n    permute_cpu(matrix, permuted_matrix, dim_1, dim_2, dim_3, dim_4);\n    clock_t end = clock();\n    double elapsed_time_cpu = (double)(end - start) / CLOCKS_PER_SEC;\n\n    // Define block and grid sizes\n    dim3 blockSize(256);\n    int totalThreads = dim_1 * dim_2 * dim_3 * dim_4;\n    int gridSize = (totalThreads + blockSize.x - 1) / blockSize.x; // Compute grid size\n\n    // Launch CUDA kernel to perform permutation\n    
permute_kernel<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);\n    cudaDeviceSynchronize(); // Ensure kernel execution is complete\n\n    // Verify results\n    printf(\"Checking correctness...\\n\");\n    validate_result(d_permuted_matrix, permuted_matrix, \"permuted_matrix\", dim_1 * dim_2 * dim_3 * dim_4, 1e-5f);\n\n    printf(\"All results match.\\n\\n\");\n    // benchmark kernel\n    int repeat_times = 1000;\n    float elapsed_time = benchmark_kernel(repeat_times, permute_kernel,\n                                          d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4\n    );\n    printf(\"time gpu %.4f ms\\n\", elapsed_time);\n    printf(\"time cpu %.4f ms\\n\", elapsed_time_cpu);\n\n    // Free allocated memory\n    free(matrix);\n    free(permuted_matrix);\n    cudaFree(d_matrix);\n    cudaFree(d_permuted_matrix);\n\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/residual_forward.cu",
    "content": "/*\nKernels for residual forward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt residual_forward.cu -o residual_forward\n\nversion 1 is naive port from CPU code to kernel\n./residual_forward 1\nversion 2 packs input into 128 bit memory reads\n./residual_forward 2\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n\n#define ENABLE_BF16\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference lol\n\nvoid residual_forward_cpu(float* out, const float* inp1, const float* inp2, int N) {\n    for (int i = 0; i < N; i++) {\n        out[i] = inp1[i] + inp2[i];\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n// elementwise ops are nice and ez\n__global__ void residual_forward_kernel1(floatX* out, const floatX* inp1, const floatX* inp2, int N) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < N) {\n        out[idx] = (floatX)((float)inp1[idx] + (float)inp2[idx]);\n    }\n}\n\n__global__ void residual_forward_kernel2(floatX* out, const floatX* inp1, const floatX* inp2, int N) {\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n    if (idx < N) {\n        x128 packed_out;\n        x128 packed_inp1 = load128cs(inp1 + idx);\n        x128 packed_inp2 = load128cs(inp2 + idx);\n        for (int k = 0; k < packed_inp1.size; ++k)\n        {\n            packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]);\n        }\n        store128(out + idx, packed_out);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid residual_forward1(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {\n    const int grid_size = ceil_div(N, block_size);\n    residual_forward_kernel1<<<grid_size, block_size>>>(out, inp1, inp2, N);\n    
cudaCheck(cudaGetLastError());\n}\n\nvoid residual_forward2(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {\n    const int grid_size = ceil_div(N, (int)(block_size * x128::size));\n    residual_forward_kernel2<<<grid_size, block_size>>>(out, inp1, inp2, N);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid residual_forward(int kernel_num,\n                  floatX* out,\n                  const floatX* inp1,\n                  const floatX* inp2,\n                  int N,\n                  int block_size) {\n    switch (kernel_num) {\n        case 1:\n            residual_forward1(out, inp1, inp2, N, block_size);\n            break;\n        case 2:\n            residual_forward2(out, inp1, inp2, N, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * C * sizeof(float));\n    float* inp1 = make_random_float(B * T * C);\n    float* inp2 = make_random_float(B * T * C);\n\n    // move to GPU\n    floatX* d_out;\n    floatX* d_inp1;\n    floatX* d_inp2;\n    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX)));\n    cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX)));\n    cudaCheck(memcpy_convert(d_inp1, inp1, B * T * C));\n    cudaCheck(memcpy_convert(d_inp2, inp2, B * T * C));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    residual_forward_cpu(out, inp1, inp2, B * T * C);\n\n\n    // 
time the kernel at different block sizes\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        residual_forward(kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size);\n#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)\n        float tol = 1e-5;\n#else\n        float tol = 1e-2f;\n#endif\n        validate_result(d_out, out, \"out\", B * T * C, tol);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 1000;\n        float elapsed_time = benchmark_kernel(repeat_times, residual_forward,\n                                              kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size\n                                              );\n\n        // napkin math: estimate the memory bandwidth achieved\n        // for each (B,T,C) output element, we do 2 read and 1 write, 4 bytes each\n        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s\n        long memory_ops = B * T * C * 3 * 4;\n        float memory_bandwidth = memory_ops / elapsed_time / 1e6;\n\n        printf(\"block_size %4d | time %.4f ms | bandwidth %.2f GB/s\\n\", block_size, elapsed_time, memory_bandwidth);\n    }\n\n    // free memory\n    free(out);\n    free(inp1);\n    free(inp2);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp1));\n    cudaCheck(cudaFree(d_inp2));\n\n    return 0;\n}\n"
  },
  {
    "path": "dev/cuda/softmax_forward.cu",
    "content": "/*\nKernels for softmax forward pass.\n\nCompile example:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt softmax_forward.cu -o softmax_forward\n\nversion 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C\n./softmax_forward 1\n\nversion 2 is a fused kernel that parallelizes over all of B,T,C\n./softmax_forward 2\n\nversion 3 uses intra-warp reductions for maxval and sumval, must use block_size=32\n./softmax_forward 3\n\nversion 4 uses both intra-warp reductions and shared memory for inter-warp reductions\nso it can tolerate any block_size % 32 == 0. this is hopefully the most efficient version\n./softmax_forward 4\n\nversion 5 is naive port from CPU code (softmax_online) to kernel: parallelizes over B,T, loops over C\n./softmax_forward 5\n\nversion 6 is softmax_online that parallelizes over all of B,T,C\n./softmax_forward 6\n\nversion 7 is softmax optimized for very large C.\n./softmax_forward 7\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <cuda_runtime.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include \"common.h\"\n\n// ----------------------------------------------------------------------------\n// CPU code reference\n\nvoid softmax_forward_cpu(float* out, const float* inp, int N, int C) {\n    // inp is (N, C)\n    // out is (N, C), each row of inp will get softmaxed\n    for (int i = 0; i < N; i++) {\n        const float* inp_row = inp + i * C;\n        float* out_row = out + i * C;\n\n        float maxval = -INFINITY;\n        for (int j = 0; j < C; j++) {\n            if (inp_row[j] > maxval) {\n                maxval = inp_row[j];\n            }\n        }\n        // Note: since we want to ensure that the CUDA-kernels are accurate,\n        // we do this accumulation in higher precision, so we can be assured\n        // that our ground-truth is of high quality.\n        double sum = 0.0;\n        for (int j = 0; j < C; j++) {\n            out_row[j] 
= expf(inp_row[j] - maxval);\n            sum += out_row[j];\n        }\n        float norm = 1.f / (float)sum;\n        for (int j = 0; j < C; j++) {\n            out_row[j] *= norm;\n        }\n    }\n}\n\n\n// online version of softmax on CPU from the paper \"Online normalizer calculation for softmax\"\nvoid softmax_forward_online_cpu(float* out, const float* inp, int N, int C) {\n    // inp is (N, C)\n    // out is (N, C), each row of inp will get softmaxed\n    for (int i = 0; i < N; i++) {\n        const float* inp_row = inp + i * C;\n        float* out_row = out + i * C;\n\n        float maxval = -INFINITY;\n        float sum = 0.0f;\n\t\tfor (int j = 0; j < C; j++) {\n\t\t\tfloat maxval_prev = maxval;\n\t\t\tif (inp_row[j] > maxval) {\n\t\t\t\tmaxval = inp_row[j];\n\t\t\t\tsum = sum * expf(maxval_prev - maxval) + expf(inp_row[j] - maxval);\n\t\t\t} else {\n\t\t\t\tsum += expf(inp_row[j] - maxval);\n\t\t\t}\n\t\t}\n\n        for (int j = 0; j < C; j++) {\n            out_row[j] = expf(inp_row[j] - maxval) / sum;\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPU kernels\n\n__global__ void softmax_forward_kernel1(float* out, const float* inp, int N, int C) {\n    // inp is (N, C)\n    // out is (N, C), each row of inp will get softmaxed\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < N) {\n        const float* inp_row = inp + i * C;\n        float* out_row = out + i * C;\n\n        float maxval = -INFINITY;\n        for (int j = 0; j < C; j++) {\n            if (inp_row[j] > maxval) {\n                maxval = inp_row[j];\n            }\n        }\n        double sum = 0.0;\n        for (int j = 0; j < C; j++) {\n            out_row[j] = expf(inp_row[j] - maxval);\n            sum += out_row[j];\n        }\n        for (int j = 0; j < C; j++) {\n            out_row[j] /= (float)sum;\n        }\n    }\n}\n\n__global__ void softmax_forward_kernel2(float* out, const float* inp, int N, 
int C) {\n    // inp is (N, C)\n    // in each row of C elements, first calculates maxval, then returns expf(val - maxval)\n    extern __shared__ float shared[];\n    int idx = blockIdx.x; // ranges [0, N)\n    int tid = threadIdx.x; // ranges [0, block_size)\n    int block_size = blockDim.x;\n    const float* x = inp + idx * C; // idx-th row of inp\n    // thread coarsening\n    float maxval = -INFINITY;\n    for (int i = tid; i < C; i += block_size) {\n        maxval = fmaxf(maxval, x[i]);\n    }\n    shared[tid] = maxval;\n    // reductions\n    for (int stride = block_size / 2; stride >= 1; stride /= 2) {\n        __syncthreads();\n        if (tid < stride) {\n            shared[tid] = fmaxf(shared[tid], shared[tid + stride]);\n        }\n    }\n    __syncthreads();\n    float offset = shared[0];\n    // compute expf and write the result to global memory\n    for (int i = tid; i < C; i += block_size) {\n        out[idx * C + i] = expf(x[i] - offset);\n    }\n    __syncthreads();\n    // thread coarsening again, for the sum\n    x = out + idx * C; // idx-th row of out\n    float sumval = 0.0f;\n    for (int i = tid; i < C; i += block_size) {\n        sumval += x[i];\n    }\n    shared[tid] = sumval;\n    // reductions\n    for (int stride = block_size / 2; stride >= 1; stride /= 2) {\n        __syncthreads();\n        if (tid < stride) {\n            shared[tid] += shared[tid + stride];\n        }\n    }\n    // broadcast the sum to all threads in the block\n    __syncthreads();\n    float sum = shared[0];\n    // divide the input values by the sum\n    for (int i = tid; i < C; i += block_size) {\n        out[idx * C + i] = x[i] / sum;\n    }\n}\n\n// warp-level reduction for finding the maximum value\n__device__ float warpReduceMax(float val) {\n    for (int offset = 16; offset > 0; offset /= 2) {\n        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));\n    }\n    return val;\n}\n\n__global__ void softmax_forward_kernel3(float* out, const float* 
inp, int N, int C) {\n    // kernel must use block size of 32\n    extern __shared__ float shared[];\n    int idx = blockIdx.x;\n    int tid = threadIdx.x;\n    const float* x = inp + idx * C;\n\n    // Thread coarsening and within-warp reduction for maxval\n    float maxval = -INFINITY;\n    for (int i = tid; i < C; i += blockDim.x) {\n        maxval = fmaxf(maxval, x[i]);\n    }\n    maxval = warpReduceMax(maxval);\n\n    // Broadcast maxval within the warp\n    float offset = __shfl_sync(0xFFFFFFFF, maxval, 0);\n\n    // Compute expf and write the result to global memory\n    for (int i = tid; i < C; i += blockDim.x) {\n        out[idx * C + i] = expf(x[i] - offset);\n    }\n\n    // Thread coarsening and within-warp reduction for sumval\n    x = out + idx * C;\n    float sumval = 0.0f;\n    for (int i = tid; i < C; i += blockDim.x) {\n        sumval += x[i];\n    }\n    // No need to broadcast sumval since all threads in the warp will have the same value\n    // (due to the fact that we're using __shfl_xor_sync)\n    sumval = warpReduceSum(sumval);\n\n    // Divide the input values by the sum\n    for (int i = tid; i < C; i += blockDim.x) {\n        out[idx * C + i] = x[i] / sumval;\n    }\n}\n\n__global__ void softmax_forward_kernel4(float* out, const float* inp, int N, int C) {\n    // out is (N, C) just like inp. 
Each row of inp will get softmaxed.\n    // same as kernel3, but can handle any block size (multiple of 32)\n    // each row of C elements is handled by block_size threads\n    // furthermore, each block_size threads get executed in warps of 32 threads\n\n    // special reduction operations warpReduceMax/warpReduceSum are used for intra-warp reductions\n    // shared memory is used for inter-warp reduction\n    extern __shared__ float shared[];\n    int idx = blockIdx.x;\n    int tid = threadIdx.x;\n    int warpId = threadIdx.x / 32; // warp index within a block\n    int laneId = threadIdx.x % 32; // thread index within a warp\n\n    // the number of warps per block. recall that blockDim.x is block_size\n    int warpsPerBlock = blockDim.x / 32;\n\n    // shared[] must be allocated to have warpsPerBlock elements\n    // those will be used for max and sum values\n    float* max_or_sum_storage = shared;\n\n    // one row of inp, i.e. inp[idx, :] of shape (C,)\n    const float* x = inp + idx * C;\n\n    // first, thread coarsening by directly accessing global memory in series\n    float maxval = -INFINITY;\n    for (int i = tid; i < C; i += blockDim.x) {\n        maxval = fmaxf(maxval, x[i]);\n    }\n    // now within-warp reductions for maxval\n    maxval = warpReduceMax(maxval);\n\n    // the 0th thread of each warp writes the maxval of that warp to shared memory\n    if (laneId == 0) max_or_sum_storage[warpId] = maxval;\n    __syncthreads();\n\n    // now the 0th thread of the block reduces the max values in shared memory, i.e. 
across warps\n    if (tid == 0) {\n        float val = max_or_sum_storage[tid];\n        for (int i = 1; i < warpsPerBlock; i++) {\n            val = fmaxf(val, max_or_sum_storage[i]);\n        }\n        // store the final max in the first position\n        max_or_sum_storage[0] = val;\n    }\n    __syncthreads();\n    // broadcast the max to all threads\n    float offset = max_or_sum_storage[0];\n\n    // compute expf and write the result to global memory\n    for (int i = tid; i < C; i += blockDim.x) {\n        out[idx * C + i] = expf(x[i] - offset);\n    }\n\n    // okay now we calculated exp(x - max(x))\n    // step 2: sum all the values and divide by the sum\n\n    // thread coarsening for sum\n    x = out + idx * C;\n    float sumval = 0.0f;\n    for (int i = tid; i < C; i += blockDim.x) {\n        sumval += x[i];\n    }\n    // within-warp reduction for sumval\n    sumval = warpReduceSum(sumval);\n\n    // write sumval to shared memory\n    if (laneId == 0) max_or_sum_storage[warpId] = sumval;\n    __syncthreads();\n\n    // inter-thread reduction of sum\n    if (tid == 0) {\n        float val = max_or_sum_storage[tid];\n        for (int i = 1; i < warpsPerBlock; ++i) {\n            val += max_or_sum_storage[i];\n        }\n        max_or_sum_storage[0] = val;\n    }\n    __syncthreads();\n    // broadcast the sum to all threads\n    float sum = max_or_sum_storage[0];\n\n    // divide the whole row by the sum\n    for (int i = tid; i < C; i += blockDim.x) {\n        out[idx * C + i] = x[i] / sum;\n    }\n}\n\n__global__ void softmax_forward_online_kernel1(float* out, const float* inp, int N, int C) {\n    // inp is (N, C)\n    // out is (N, C), each row of inp will get softmaxed\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < N) {\n        const float* inp_row = inp + i * C;\n        float* out_row = out + i * C;\n\n        float maxval = -INFINITY;\n        double sum = 0.0;\n        for (int j = 0; j < C; j++) {\n            float 
maxval_prev = maxval;\n            float current_val = inp_row[j];\n\t\t\tif (current_val > maxval) {\n\t\t\t\tmaxval = current_val;\n\t\t\t\tsum = sum * expf(maxval_prev - maxval) + expf(current_val - maxval);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tsum += expf(current_val - maxval);\n\t\t\t}\n\t\t}\n\n        for (int j = 0; j < C; j++) {\n            out_row[j] = expf(inp_row[j] - maxval) / sum;\n        }\n    }\n}\n\n// struct for the reduction operation, guarantees 8-byte alignment\nstruct __align__(8) SumMax\n{\n    float maxval;\n    float sum;\n};\n\n// forceinline helps avoid function call overhead\n__device__ __forceinline__ SumMax reduce_sum_max_op(SumMax a, SumMax b) {\n    bool a_bigger = (a.maxval > b.maxval);\n    SumMax bigger_m = a_bigger ? a : b;\n    SumMax smaller_m = a_bigger ? b : a;\n    SumMax res;\n    res.maxval = bigger_m.maxval;\n    res.sum = bigger_m.sum + smaller_m.sum * expf(smaller_m.maxval - bigger_m.maxval);\n    return res;\n}\n\n__global__ void softmax_forward_online_kernel2(float* out, const float* inp, int N, int C) {\n\tnamespace cg = cooperative_groups;\n\tcg::thread_block block = cg::this_thread_block();\n\tcg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n\tint idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n\tif (idx >= N) {\n\t\treturn;\n\t}\n\n\t// one row of inp, i.e. 
inp[idx, :] of shape (C,)\n\tconst float* x = inp + idx * C;\n\n    // base case for the reduction\n    SumMax sm_partial;\n\tsm_partial.maxval = -INFINITY;\n\tsm_partial.sum = 0.0f;\n\n\t// first, thread coarsening by directly accessing global memory in series\n\tfor (int i = warp.thread_rank(); i < C; i += warp.size()) {\n\t\tsm_partial = reduce_sum_max_op(sm_partial, { x[i], 1.0f });\n\t}\n\n    // second, the reduction\n\tSumMax sm_total = cg::reduce(warp, sm_partial, reduce_sum_max_op);\n\n\t// divide the whole row by the sum\n\tfor (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        // the below is equivalent to\n        // out[idx * C + i] = expf(x[i] - sm_total.maxval) / sm_total.sum;\n        // but uses special instruction that bypasses the cache\n        __stcs(out + idx * C + i, expf(x[i] - sm_total.maxval) / sm_total.sum);\n\t}\n}\n\n__global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int C) {\n    // out is (N, C) just like inp. Each row of inp will get softmaxed.\n    // same as kernel4, but optimised for very large Cs with advanced unrolling\n\n    // The trick is to read into a register array (all indices known at compile time)\n    // and always read UNROLL_FACTOR values to maximise memory level parallelism\n    // even if we would be out of bounds, we set the index to min(C-1, idx)\n    // so we just do some unnecessary reads (obviously bad for small C)\n    // the writes are in a separate loop with a conditional check for out of bounds\n    // making it separate is necessary to convince the compiler to do the right thing\n    const int UNROLL_FACTOR = 8;\n    const int warpsPerBlock = blockDim.x / 32;\n\n    extern __shared__ float shared[];\n    int idx = blockIdx.x;\n    int tid = threadIdx.x;\n    int warpId = threadIdx.x / 32; // warp index within a block\n    int laneId = threadIdx.x % 32; // thread index within a warp\n\n    // shared[] must be allocated to have 2 * warpsPerBlock elements\n    // first 
half for max values, the second half for sum values\n    float* maxvals = shared;\n    float* sumvals = &shared[warpsPerBlock];\n\n    if (tid >= C) {\n        maxvals[warpId] = -INFINITY;\n        sumvals[warpId] = 0.0f;\n        return;\n    }\n\n    const float* x = inp + idx * C; // input\n    float* y = out + idx * C; // output\n\n    // first, thread coarsening by directly accessing global memory in series\n    float maxval = -INFINITY;\n    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {\n        #pragma unroll\n        for (int u = 0; u < UNROLL_FACTOR; u++) {\n            maxval = fmaxf(maxval, x[min(C - 1, i + u*blockDim.x)]);\n        }\n    }\n\n    // now within-warp reductions for maxval\n    maxval = warpReduceMax(maxval);\n    // the 0th thread of each warp writes the maxval of that warp to shared memory\n    if (laneId == 0) maxvals[warpId] = maxval;\n    __syncthreads();\n    // now the 0th thread reduces the maxvals in shared memory, i.e. across warps\n    if (tid == 0) {\n        float val = maxvals[tid];\n        #pragma unroll\n        for (int i = 1; i < warpsPerBlock; i++) {\n            val = fmaxf(val, maxvals[i]);\n        }\n        // store the final max in the first position\n        maxvals[0] = val;\n    }\n    __syncthreads();\n    // broadcast the max to all threads\n    float offset = maxvals[0];\n\n    // compute expf and write the result to global memory\n    // + thread coarsening for sum\n    float sumval = 0.0f;\n    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {\n        float reg_array[UNROLL_FACTOR];\n        #pragma unroll\n        for (int u = 0; u < UNROLL_FACTOR; u++) {\n            reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]);\n        }\n        #pragma unroll\n        for (int u = 0; u < UNROLL_FACTOR; u++) {\n            if (i + u*blockDim.x < C) {\n                float output = expf(reg_array[u] - offset);\n                y[min(C - 1, i + u*blockDim.x)] = output; // compiler 
likes redundant min()?!\n                sumval += output; // combined into the same loop unlike kernel3\n            }\n        }\n    }\n\n    // okay now we calculated exp(x - max(x))\n    // step 2: sum all the values and divide by the sum\n\n    // within-warp reduction for sumval\n    sumval = warpReduceSum(sumval);\n    // write sumval to shared memory\n    if (laneId == 0) sumvals[warpId] = sumval;\n    __syncthreads();\n    // inter-thread reduction of sum\n    if (tid == 0) {\n        float val = sumvals[tid];\n        #pragma unroll\n        for (int i = 1; i < warpsPerBlock; ++i) {\n            val += sumvals[i];\n        }\n        sumvals[0] = val;\n    }\n    __syncthreads();\n    // broadcast the sum to all threads\n    float sum = sumvals[0];\n\n    // divide the whole row by the sum\n    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {\n        float reg_array[UNROLL_FACTOR];\n        #pragma unroll\n        for (int u = 0; u < UNROLL_FACTOR; u++) {\n            reg_array[u] = y[min(C - 1, i + u*blockDim.x)];\n        }\n        #pragma unroll\n        for (int u = 0; u < UNROLL_FACTOR; u++) {\n            if (i + u*blockDim.x < C) {\n                y[i + u*blockDim.x] = reg_array[u] / sum;\n            }\n        }\n    }\n}\n\n__global__ void softmax_forward_online_kernel8(float* out, const float* inp, int N, int C) {\n    // online softmax paper: http://arxiv.org/abs/1805.02867\n    // online softmax reduces loops from 3 to 2\n    // which is done by calculating sumval and maxval in one loop\n    const int warpsPerBlock = blockDim.x / warpSize;\n    int tid = threadIdx.x;\n\n    if (tid >= C) {\n        return;\n    }\n\n    int warpId = tid / warpSize;\n    int laneId = tid % warpSize;\n    // one warp one row\n    int row = blockIdx.x * warpsPerBlock + warpId;\n\n    if (row >= N) {\n        return;\n    }\n\n    const float* x = inp + row * C;\n    float* const y = out + row * C;\n\n    // merge calculating maxval and sumval in 
one loop\n    // which is an arithmetic improvement from online softmax over normal softmax\n    float maxval = -INFINITY, sumval = 0.0f, bigger;\n    for (int i = laneId; i < C; i += warpSize) {\n        // when updating the maxval, dynamically updates the previous sumval by\n        // multiplying e^{previous_maxval - current_maxval}\n        bigger = fmaxf(maxval, x[i]);\n        sumval = sumval * expf(maxval - bigger) + expf(x[i] - bigger);\n        maxval = bigger;\n    }\n\n    // use warp functions instead of cooperative groups for better readability\n    // calculate the warp-wise maxval and sumval\n    float offsetMaxval, offsetSumval;\n    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {\n        __syncwarp();\n        offsetMaxval = __shfl_down_sync(0xFFFFFFFF, maxval, offset);\n        offsetSumval = __shfl_down_sync(0xFFFFFFFF, sumval, offset);\n        if (offsetMaxval > maxval) {\n            sumval *= expf(maxval - offsetMaxval);\n            maxval = offsetMaxval;\n        } else {\n            offsetSumval *= expf(offsetMaxval - maxval);\n        }\n        sumval += offsetSumval;\n    }\n\n    // sync the warp-wise maxval and sumval\n    // which are also the maxval and sumval of one row in C\n    maxval = __shfl_sync(0xFFFFFFFF, maxval, 0);\n    sumval = __shfl_sync(0xFFFFFFFF, sumval, 0);\n\n    for (int i = laneId; i < C; i += warpSize) {\n        y[i] = expf(x[i] - maxval) / sumval;\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\nvoid softmax_forward1(float* out, const float* inp, int N, int C, const int block_size) {\n    const int grid_size = ceil_div(N, block_size);\n    softmax_forward_kernel1<<<grid_size, block_size>>>(out, inp, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid softmax_forward2(float* out, const float* inp, int N, int C, const int block_size) {\n    int grid_size = N;\n    size_t shared_mem_size = block_size * sizeof(float);\n    
softmax_forward_kernel2<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);\n}\n\nvoid softmax_forward3(float* out, const float* inp, int N, int C, int block_size) {\n    block_size = 32; // awkward but ok. this one only works with block size 32\n    int grid_size = N;\n    size_t shared_mem_size = block_size * sizeof(float);\n    softmax_forward_kernel3<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);\n}\n\nvoid softmax_forward4(float* out, const float* inp, int N, int C, int block_size) {\n    int grid_size = N;\n    // for each warp in the block we need a float that will be used for both maxval and sumval\n    size_t shared_mem_size = block_size / 32 * sizeof(float);\n    softmax_forward_kernel4<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);\n}\n\nvoid softmax_forward_online1(float* out, const float* inp, int N, int C, int block_size) {\n    const int grid_size = ceil_div(N, block_size);\n    softmax_forward_online_kernel1 <<<grid_size, block_size >>> (out, inp, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid softmax_forward_online2(float* out, const float* inp, int N, int C, int block_size) {\n    const int grid_size = ceil_div(N * 32, block_size);\n    softmax_forward_online_kernel2 <<<grid_size, block_size >>> (out, inp, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid softmax_forward7(float* out, const float* inp, int N, int C, int block_size) {\n    int grid_size = N;\n    size_t shared_mem_size = 2 * block_size / 32 * sizeof(float);\n    softmax_forward_kernel7<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);\n}\n\nvoid softmax_forward_online8(float* out, const float* inp, int N, int C, int block_size) {\n    const int grid_size = ceil_div(N * 32, block_size);\n    softmax_forward_online_kernel8<<<grid_size, block_size>>>(out, inp, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel version dispatch\nvoid softmax_forward(int kernel_num, float* out, const float* inp, int N, int C, const int 
block_size) {\n    switch (kernel_num) {\n        case 1:\n            softmax_forward1(out, inp, N, C, block_size);\n            break;\n        case 2:\n            softmax_forward2(out, inp, N, C, block_size);\n            break;\n        case 3:\n            softmax_forward3(out, inp, N, C, block_size);\n            break;\n        case 4:\n            softmax_forward4(out, inp, N, C, block_size);\n            break;\n        case 5:\n            softmax_forward_online1(out, inp, N, C, block_size);\n            break;\n        case 6:\n            softmax_forward_online2(out, inp, N, C, block_size);\n            break;\n        case 7:\n            softmax_forward7(out, inp, N, C, block_size);\n            break;\n        case 8:\n            softmax_forward_online8(out, inp, N, C, block_size);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n// ----------------------------------------------------------------------------\n\nint main(int argc, char **argv) {\n    srand(0);\n\n    int B = 8;\n    int T = 1024;\n    int V = 50257;\n\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * T * V * sizeof(float));\n    float* inp = make_random_float(B * T * V);\n\n    // make the input less uniformly random: Otherwise, all probabilities will be basically zero,\n    // and the tests are not actually meaningful.\n    const int* outliers = make_random_int(B * T * 3, V);\n    for(int k = 0; k < 3; ++k) {\n        for(int j = 0; j < B * T; ++j) {\n            inp[j * V + outliers[j*3 + k]] *= 20;\n        }\n    }\n\n    // move to GPU\n    float* d_out;\n    float* d_inp;\n    cudaCheck(cudaMalloc(&d_out, B * T * V * sizeof(float)));\n    cudaCheck(cudaMalloc(&d_inp, B * T * V * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * V * sizeof(float), cudaMemcpyHostToDevice));\n\n    // read 
kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    int block_sizes[] = {32, 64, 128, 256, 512, 1024};\n\n    softmax_forward_cpu(out, inp, B * T, V);\n    {\n        float max_el = -INFINITY;\n        for(int i = 0; i <  B * T * V; ++i) {\n            max_el = max(max_el, out[i]);\n        }\n        assert(max_el > 1e-4);\n        printf(\"Largest output is: %f\\n\", max_el);\n    }\n\n    // first check the correctness of the kernel\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n        printf(\"Checking block size %d.\\n\", block_size);\n        softmax_forward(kernel_num, d_out, d_inp, B * T, V, block_size);\n        validate_result(d_out, out, \"out\", B * T * V, 1e-4f);\n    }\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    // time the kernel at different block sizes\n    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {\n        int block_size = block_sizes[j];\n\n        int repeat_times = 100;\n        float elapsed_time = benchmark_kernel(repeat_times, softmax_forward,\n                                              kernel_num, d_out, d_inp, B * T, V, block_size\n                                              );\n\n        printf(\"block_size %4d | time %.4f ms | per token %.2f µs\\n\", block_size, elapsed_time, elapsed_time * 1'000 / (B*T));\n    }\n\n    // free memory\n    free(out);\n    free(inp);\n    free((void*)outliers);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp));\n\n    return 0;\n}"
  },
  {
    "path": "dev/cuda/trimat_forward.cu",
    "content": "/*\nTriangular matrix multiplication as in autoregressive attention. A short story.\nby @ngc92\n\nCompile:\nnvcc -O3 --use_fast_math -lcublas -lcublasLt trimat_forward.cu -o trimat_forward -lcublas\n\nRun:\n\ncuBLAS baseline kernel\n./trimat_forward 0\n\nnaive\n./trimat_forward 1\n\nregisters\n./trimat_forward 2\n\ntri3\n./trimat_forward 3\n\ntri4\n./trimat_forward 4\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <assert.h>\n#include <float.h>\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n#include \"common.h\"\n\nstatic float* d_qkvr;   // scratch for the cublas kernel\n\n/*                    ** Chapter I - Introduction **\n *\n *  You are Trimul. You've always wanted to do fast matrix multiplication, but they said\n *  \"Don't bother, big dumb Cublas is much faster than you!\"\n *  \"I don't need to be faster than Cublas\", you replied, \"I can be smarter. Cublas calculates\n *  the entire matrix, but I need only half. If I'm more than half as fast as Cublas, I'm\n *  going to win.\"\n *\n *  So to prove everyone wrong, you enter the TriMatlon, the most prestigious competition\n *  for anyone paying Attention.\n *\n *  Before you start preparing, lets have a look at the players involved\n *\n *  First, there is the Referee (`trimul_cpu`), slow and ponderous, but producing results\n *  beyond reproof.\n *  Then, there is Cublas. Cublas' mind is so inflexible, it doesn't actually comprehend\n *  what we are trying to do here, so Cublas has brought an assistant (`permute_kernel`)\n *  that translates the competition into a task that it can solve. But once it recognizes\n *  the problem, its muscle memory kicks in, and matrix products are produced faster than\n *  the eye can see. 
Stuck in its routine, Cublas doesn't realize the task is already\n *  finished with the lower triangle, though.\n *\n *  If you can do without an assistant, and can solve the right task, then that's your opportunity\n *  to shine!\n */\n\n\n// taken from the attention forward pass\nvoid trimul_cpu(float* out, const float* inp,\n                int B, int T, int C, int NH) {\n    // inp shape: (B, T, 3, NH, HS)\n    // out shape: (B, NH, T, T)\n    int C3 = C*3;\n    int HS = C / NH; // head size\n    float scale = 1.0 / sqrtf(HS);\n\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            for (int nh = 0; nh < NH; nh++) {\n                // Q[b][nh][t][:] = inp[b][t][0][nh][:] (where : is the slice operator for hs)\n                const float* query_t = inp + b * T * C3 + t * C3 + nh * HS;\n                // out[b][nh][t][:]\n                float* out_bth = out + b * NH * T * T + nh * T * T + t * T;\n\n                // pass 1: calculate query dot key and maxval\n                for (int t2 = 0; t2 <= t; t2++) {\n                    // K[b][nh][t2][:] = inp[b][t2][1][nh][:]\n                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + nh * HS + C; // +C because it's key\n\n                    // Q[b][nh][t][:] dot K[b][nh][t2][:]\n                    float val = 0.0f;\n                    for (int i = 0; i < HS; i++) {\n                        val += query_t[i] * key_t2[i];\n                    }\n                    val *= scale;\n\n                     // out[b][nh][t][t2] = val\n                    out_bth[t2] = val;\n                }\n                for(int t2 = t + 1; t2 < T; ++t2) {\n                    // causal mask, using NAN to suppress warnings -> it could be -inf\n                    // but it doesn't matter because in validate_result we ignore infinities/NANs\n                    out_bth[t2] = NAN;\n                }\n            }\n        }\n    }\n}\n\n__global__ void permute_kernel(float* q, float* k, 
float* v,\n                               const float* inp,\n                               int B, int T, int NH, int HS) {\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, T, HS)\n    // but instead, we have a single tensor QKV (inp) of shape (B, T, 3, NH, HS)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // Q[b][nh][t][hs] = inp[b][t][0][nh][hs]\n\n    if (idx < B * NH * T * HS) {\n        int b = idx / (NH * T * HS);\n        int rest = idx % (NH * T * HS);\n        int nh = rest / (T * HS);\n        rest = rest % (T * HS);\n        int t = rest / HS;\n        int hs = rest % HS;\n\n        int inp_idx = \\\n            (b * T * 3 * NH * HS)\n            +   (t * 3 * NH * HS)\n            +       (0 * NH * HS)\n            +          (nh * HS)\n            +                hs;\n\n        q[idx] = inp[inp_idx];\n        k[idx] = inp[inp_idx + NH * HS];\n        v[idx] = inp[inp_idx + 2 * (NH * HS)];\n    }\n}\n\n\nvoid trimul_cublas(float* preatt,\n                   const float* inp,\n                   int B, int T, int C, int NH) {\n    int HS = C / NH; // head size\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    float* q, * k, * v;\n    q = d_qkvr + 0 * B * T * C;\n    k = d_qkvr + 1 * B * T * C;\n    v = d_qkvr + 2 * B * T * C;\n    int total_threads = B * NH * T * HS;\n    int num_blocks = ceil_div(total_threads, 256);\n    permute_kernel<<<num_blocks, 256>>>(q, k, v, inp, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n\n    // batched matrix multiply with cuBLAS\n    const float alpha = 1.0f / sqrtf(HS);\n    const float beta = 0.0f;\n    // This schedules in parallel B*NH matmuls of shape q@k^t = (T, HS) @ (HS, T) = (T, T).\n    // IMPORTANT NOTE: Cublas uses a column-major (and we use row-major in our codebase) representation,\n    // so this call might look confusing to you if you look at the `cublasSgemmStridedBatched` signature.\n    //\n    // In order to avoid having to 
do an additional transpose operation after this func call,\n    // we need to pass in K as the first argument and Q as the second argument, which might make you think we're computing K^T @ Q.\n    // That combined with the shapes we got after the permute kernel - (B, NH, T, HS) (I'll omit B, NH for brevity going forward)\n    // and you might think we end up with (HS, T) @ (T, HS) = (HS, HS).\n    // This is not the case. :)\n    //\n    // Cublas sees our row-major matrix (T, HS) as (HS, T), hence we set the lead dimensions to HS (see function signature).\n    // We transpose K and end up computing K^T @ Q = (T, HS) @ (HS, T) = (T, T).\n    // If you were to interpret the above formula K^T @ Q you might think we end up with:\n    // -----------------------------------\n    // k1.dot(q1) k1.dot(q2) ... k1.dot(qT)\n    // k2.dot(q1) k2.dot(q2) ... k2.dot(qT)\n    // ...\n    // kT.dot(q1) kT.dot(q2) ... kT.dot(qT)\n    // -----------------------------------\n    // But as I mentioned, Cublas is column-major!\n    // So given that the dot product is symmetric we can write k1.dot(q1) as q1.dot(k1) and transposing the above\n    // representation we can see what we actually end up with in the row-major format:\n    // -----------------------------------\n    // q1.dot(k1) q1.dot(k2) ... q1.dot(kT)\n    // q2.dot(k1) q2.dot(k2) ... q2.dot(kT)\n    // ...\n    // qT.dot(k1) qT.dot(k2) ... qT.dot(kT)\n    // -----------------------------------\n    // which is exactly what we wanted! 
:)\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle,\n                                          CUBLAS_OP_T, CUBLAS_OP_N,\n                                          T, T, HS,\n                                          &alpha,\n                                          k, HS, T * HS,\n                                          q, HS, T * HS,\n                                          &beta,\n                                          preatt, T, T * T,\n                                          B * NH));\n}\n\n/*                    ** Chapter II - Getting a Team **\n *\n *  OK, you've registered for the competition, now what to do. TriMatlon is a team competition, so first, you need\n *  to figure out what kind of team you need, and how to organize it. The individual instances and heads of the\n *  problem are completely independent, so you just can send separate teams to work there completely independently.\n *\n *  To figure out how to organize each team, you take out your spyglass (`Nsight Compute`) and look how the Cublas teams\n *  are handling their work.\n *  Turns out, you need 256 athletes in each group, and those handle 128 x 128 of the tasks. They work together in\n *  a tight square formation, 16 wide and 16 deep.\n *\n *  So, you went out and got your 100 000 friends, and split them into groups (`trimul_launcher`). 
Each group gets\n *  informed about where they should work (`trimul_global`) and goes off to do their thing (`matmul_tri_naive`).\n *  Let's observe how we're doing.\n */\n\n// using creates an alias for a function pointer\nusing matmul_fn_ptr = void(*)(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha);\n\ntemplate<matmul_fn_ptr matmul_tri>\n__global__ void __launch_bounds__(256, 2) trimul_global(float* out, const float* inp, int T, int C, int NH) {\n    // skip above the diagonal\n    if(blockIdx.y > blockIdx.x)\n        return;\n\n    // set up indices\n    int C3 = C*3;\n    int HS = C / NH; // head size\n    float scale = 1.0 / sqrtf(HS);\n\n    // we put the \"batch x head\" dimension into the z block index.\n    int b = blockIdx.z / NH;\n    int nh = blockIdx.z % NH;\n\n    // Get the base address for the current batch and head\n    // shapes -> inp (B, T, 3, NH, HS), Q (B, NH, T, HS), K (B, NH, T, HS)\n    const float* q = inp + b * T * C3 + nh * HS;  // Q[b][nh][:][:] = inp[b][:][0][nh][:]\n    const float* k = inp + b * T * C3 + nh * HS + C;  // K[b][nh][:][:] = inp[b][:][1][nh][:]\n    float* r = out + (b*NH + nh)*T*T;  // out[b][nh][:][:]\n\n    // start the multiplication\n    matmul_tri(r, T, k, C3, q, C3, T, HS, scale);\n}\n\ntemplate<matmul_fn_ptr matmul_tri>\nvoid trimul_launcher(float* out, const float* inp, int B, int T, int C, int NH) {\n    // we assume nice shapes here. Let's not make the code a mess by supporting weird shapes that you\n    // wouldn't want to use anyway.\n    assert(T % 128 == 0);\n    // No need to ceil_div, if it's not a multiple of 128, we would get wrong results anyway.\n    trimul_global<matmul_tri><<<dim3(T / 128, T / 128, NH * B), dim3(16, 16)>>>(out, inp, T, C, NH);\n    cudaCheck(cudaGetLastError());\n}\n\n/*                     ** Chapter III - ... **\n *\n *  You go over to the playing field. 
On one end of the field, there is a huge pile of funnily shaped cookie cutters.\n *  Some in the shape of animals, some in the shape of a landscape. Each group of workers has been assigned some runners,\n *  fetching the cookie cutters for them. The workers seem very relaxed, chatting with each other, lounging about.\n *  You focus in on one of them.\n *\n *  He seems to be giving an instruction to a runner, and then turns back to reading a novel. The runner, meanwhile,\n *  crosses the field and back, handing him an elephant shape. Then she's off again to pick up a savannah background.\n *  Having received the two shapes, he presses them into the dough and makes an elephant-in-the-savannah cookie. He hands\n *  the cutters back to the runner. \"Can you please fetch me an elephant and a jungle next?\"\n *  While she's on her way, he takes a sip of his cocktail.\n *  This time, she's making only one trip, keeping the elephant in her pocket (_Cache_). Still, it seems to take forever.\n *  You keep observing:\n *  - Elephant and zoo\n *  - Elephant and island\n *  ...\n *  - Lion and savannah\n *  - Lion and jungle\n *  - Lion and zoo\n *  ...\n *\n *  The worker has his poor runner fetch the same things over and over again, looking like she's about to faint from exhaustion.\n *  Even though she realizes this and always keeps one of them in her pocket, there is so much running,\n *  and little actual work happening.\n *\n *  Clearly, this isn't going to be effective, so you call a team meeting.\n */\n\n// baseline implementation: 20 ms\n__device__ void matmul_tri_naive(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {\n    // coordinate system:\n    // | - - - - - > j\n    // |\n    // |\n    // v\n    // i\n    // get coordinates of our block - each thread is responsible for a single 8x8 block.\n    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;\n    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;\n\n    // One more check 
to skip the upper diagonal in blocks that are on the diagonal.\n    // Note: we deliberately waste some compute on the jagged diagonal i.e. elements that belong\n    // to the upper triangle that should be masked out. This will be ignored due to the causal mask\n    // in the reference CPU implementation when used in the `validate_result` function.\n    // Alternatively this check should be done in the nested for loop below -> if (i > j) return.\n    if(j_base > i_base)\n        return;\n\n    // Simple nested loop that calculates 8x8 results in one thread.\n    for(int io = 0; io < 8; ++io) {\n        int i = i_base + io;\n        for(int jo = 0; jo < 8; ++jo) {\n            int j = j_base + jo;\n            float val = 0;\n            for (int s = 0; s < HS; ++s) {\n                val += q[i * QS + s] * k[j * KS + s];\n            }\n            p[i * PS + j] = val * alpha;\n        }\n    }\n}\n\n/*                     ** Chapter IV - ... **\n *\n *  Each worker is producing 64 combined cookies from 8 animals and 8 landscapes. They send their runners 64 times\n *  to fetch the corresponding shapes. This is terribly inefficient; The runners need a minute or so for each trip,\n *  but making a cookie can be done in just a second.\n *\n *  \"Let's try something different tomorrow: Just get all 16 cookie cutters that you need, and do all 64 combinations\n *  of them! See all this free space on your workbench (_registers_), you can keep them all there for easy access.\"\n *\n *  The next morning, you come back to the field for another practice session. Initially, there is bustling activity\n *  with the runners, picking up 16 shapes for each worker. But then, the workers have to put down their newspapers\n *  and start making cookies. Now there are 64 combinations, so it takes them a full minute.\n *\n *  Not all groups of workers are equally fast. 
When the first group finishes with all animal-landscape combinations,\n *  they already start asking the runners for the next set of cookie cutters, combining plants and houses. Even though\n *  the workers are much busier than before, they are still spending most of their time just waiting.\n *\n *  Still, instead of being busy for 20 hours, your team is now done with the task in just 3h 30 minutes; already, this\n *  is five times faster.\n *\n *  You think to yourself: \"Why should we stop at 8 x 8 combinations? Let's do 16 x 16, that's only twice the work for\n *  the runners, but four times as much for the actual workers.\"\n *  You head over to the baking area, and make that suggestion to one of your team leaders.\n *  \"In theory, that sounds great\", she agrees, \"but see, we only have limited space on our workbenches (_registers_).\n *  There is still some room left, but we simply cannot bake 256 cookies at the same time, sorry.\"\n *\n *  A different strategy is needed, then.\n */\n\n// reorganize loops to enable data reuse: 3.5 ms\n__device__ void matmul_tri_registers(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {\n    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;\n    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;\n\n    if (j_base > i_base)\n        return;\n\n    // shift our pointers to the sub-block this thread is responsible for\n    q += i_base * QS;\n    k += j_base * KS;\n    p += i_base * PS + j_base;\n\n    float vals[8][8] = {};\n    for (int hs = 0; hs < HS; ++hs) {\n        float lhs[8];\n        float rhs[8];\n        for (int u = 0; u < 8; ++u) {\n            lhs[u] = q[u * QS + hs];\n            rhs[u] = k[u * KS + hs];\n        }\n\n        for (int i = 0; i < 8; ++i) {\n            for (int j = 0; j < 8; ++j) {\n                vals[i][j] += lhs[i] * rhs[j];\n            }\n        }\n    }\n\n    for (int i = 0; i < 8; ++i) {\n        for (int j = 0; j < 8; ++j) {\n            p[i * 
PS + j] = vals[i][j] * alpha;\n        }\n    }\n}\n\n/*                     ** Chapter IV - By the Bucketload **\n *\n *  Despite the hectic activity, you pick out one of the runners. \"Why are you always bringing just one shape? Wouldn't\n *  it be much more efficient if you took more than one?\"\n *  \"Of course\", the runner answers, \"but they've asked me for an elephant, a lion, a zebra, and a goldfish. These\n *  are all over the place, I can't just pick them up at one spot (_strided access_).\"\n *  \"But the lion is right next to the palm tree. You could bring those two together?\", you confirm.\n *  \"Yes\", he says, \"if they just asked for the different categories at the same time, that would make things\n *  so much easier. See, I have this bucket, I could carry lots of things in one go if I could just scoop them up\n *  from the same place (_coalesced access_).\"\n *\n *  OK, then let's fetch the first animal, first plant, first vehicle, and first landmark shape in one go (_vectorized load_).\n *  [Here, the metaphor breaks down a bit: Since we're accumulating all the results, getting more data at the same time\n *  depth-wise doesn't require more space on the workbench. We're stacking the cookies!]\n *\n *  You also streamline the shape combination further. Instead of picking up all animals and landscapes at once, it is\n *  more efficient, using less workbench space, to just pick up all animals. Then, you get one landscape, combine it\n *  with all animals, get the next landscape, combine, and so on.\n *\n *  In this way, instead of 2 x 8 x 4 cookie cutters that take up space, you only need (8+1) x 4 at the same time.\n *\n *  With these optimizations, you are down to 100 minutes for this task. Still slower than Cublas, but not by much.\n *\n *  In the arena, each team also has access to a small storage hut, much closer to their workbenches than the piles of\n *  cookie cutters on the other side. 
Cublas is using them heavily, so maybe you should, too.\n */\n\n// convenient helper functions to make the code below more readable\n__device__ float4 ld_vec(const float* address) {\n    return *reinterpret_cast<const float4*>(address);\n}\n\n__device__ void st_vec(float* address, float4 val) {\n    *reinterpret_cast<float4*>(address) = val;\n}\n\n// vector instructions for coalesced memory access: 1.7 ms\n__device__ void matmul_tri3(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {\n    // Same logic as previous kernel we just load in float4 to improve coalescing\n    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;\n    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;\n\n    if (j_base > i_base)\n        return;\n\n    // shift our pointers to the sub-block this thread is responsible for\n    q += i_base * QS;\n    k += j_base * KS;\n    p += i_base * PS + j_base;\n\n    float vals[8][8] = {};\n    for (int hs = 0; hs < HS; hs += 4) {\n        // load in float4 to improve coalescing\n        float4 rhs[8];\n        for (int u = 0; u < 8; ++u) {\n            rhs[u] = ld_vec(k + u * KS + hs);\n        }\n\n        for (int i = 0; i < 8; ++i) {\n            // no need to keep lhs around for the i loop, it's only reused in the j loop anyway.\n            float4 lhs = ld_vec(q + i * QS + hs);\n            for (int j = 0; j < 8; ++j) {\n                vals[i][j] += lhs.x * rhs[j].x;\n                vals[i][j] += lhs.y * rhs[j].y;\n                vals[i][j] += lhs.z * rhs[j].z;\n                vals[i][j] += lhs.w * rhs[j].w;\n            }\n        }\n    }\n\n    for (int i = 0; i < 8; ++i) {\n        for (int j = 0; j < 8; j += 4) {\n            float4 result;\n            result.x = vals[i][j + 0] * alpha;\n            result.y = vals[i][j + 1] * alpha;\n            result.z = vals[i][j + 2] * alpha;\n            result.w = vals[i][j + 3] * alpha;\n            st_vec(p + i * PS + j, result);\n        }\n    }\n}\n\n/* 
                    ** Chapter V - Sharing is Caring **\n *\n *  You take a look around the shed, and see that there are 32 shelves there. They are much larger than the workbenches,\n *  giving you enough space for all the cookie cutters needed by the entire team.\n *\n *  Within the team, workers have banded together in groups of 32. They are always doing the same thing, reducing the\n *  amount of effort required for coordination. However, that also means that if you send them all to pick up different\n *  cookie cutters from the same shelf, they will have to wait and queue up (_shared memory bank conflict_).\n *\n *  In order to achieve maximum efficiency, we send the runners fetching cutters with the maximum bucket size: 32 different\n *  categories at the same time.\n *\n *  [I'm having trouble getting the specifics into the story in a sensible way. For now, please read the code for more\n *  details.]\n *\n */\n__device__ void matmul_tri4(float* p, int PS, const float* k, int KS, const float* q, int QS, int T, int HS, float alpha) {\n    int i_base = 128 * blockIdx.x + 8 * threadIdx.x;\n    int j_base = 128 * blockIdx.y + 8 * threadIdx.y;\n\n    // we need all threads for loading data, so none of them can chicken out early, even\n    // if they are not responsible for any useful result.\n    if (blockIdx.y > blockIdx.x)\n        return;\n\n    q += 128 * blockIdx.x * QS;\n    k += 128 * blockIdx.y * KS;\n\n    __shared__ float lhs_s[128][32];\n    __shared__ float rhs_s[128][32];\n\n    float vals[8][8] = {};\n    for (int so = 0; so < HS; so += 32) {\n        // Read a large slice of the input, worked on together by all threads.\n        // They are organized differently for this part. 
We want to ensure\n        // fully coalesced loads, so we let a single warp handle consecutive\n        // addresses, which means we need to combine two threadIdx.y values\n        // in one read operation.\n        // note: threads may read data here that they don't need themselves.\n        //       this really is a block-level operation.\n        // note2: 16x16 threads (i.e. the block) will, through this for loop, fetch 32 dims from 128 keys and 128 queries\n        // i.e. from Q/K, of shape (T, HS) take q[:128, so*32:(so+1)*32] and k[:128, so*32:(so+1)*32]\n        __syncthreads();\n        for(int y = threadIdx.y / 2; y < 128; y += 8) {\n            int xo = (threadIdx.y % 2) * 16;\n            lhs_s[y][threadIdx.x + xo] = q[y * QS + so + threadIdx.x + xo];\n            rhs_s[y][threadIdx.x + xo] = k[y * KS + so + threadIdx.x + xo];\n        }\n        __syncthreads();\n\n        // Now we compute a partial dot product (only 32 dims) for all combinations of keys and queries (128x128).\n        // Each thread does 8x8 of these partial dot products.\n        // E.g. thread (0,0) covers queries 0-7 and keys 0-7. More generally first row of threads\n        // (0,:) covers queries 0-7 with keys 0-127 and so on.\n        // In the next iterations of the outer (`so`) loop we'll be accumulating values to `vals` until we\n        // get the full dot product. 
We then later deposit it into the output matrix for all 8x8 blocks\n        // that are below the diagonal.\n        for (int si = 0; si < 32; ++si) {\n            float rhs[8];\n            for (int u = 0; u < 8; ++u) {\n                rhs[u] = rhs_s[u + 8 * threadIdx.y][(si + threadIdx.x) % 32];\n            }\n\n            for (int ii = 0; ii < 8; ++ii) {\n                float lhs = lhs_s[ii + 8 * threadIdx.x][(si + threadIdx.x) % 32];\n                for (int ji = 0; ji < 8; ++ji) {\n                    vals[ii][ji] += lhs * rhs[ji];\n                }\n            }\n        }\n    }\n\n    // don't write above the diagonal\n    if (j_base > i_base)\n        return;\n\n    for (int ii = 0; ii < 8; ++ii) {\n        for (int ji = 0; ji < 8; ji += 4) {\n            int i = i_base + ii;\n            int j = j_base + ji;\n            float4 result;\n            result.x = vals[ii][ji + 0] * alpha;\n            result.y = vals[ii][ji + 1] * alpha;\n            result.z = vals[ii][ji + 2] * alpha;\n            result.w = vals[ii][ji + 3] * alpha;\n            st_vec(p + i * PS + j, result);\n        }\n    }\n}\n\n/*                     ** Chapter VI - Competition Day **\n *\n * Finally, you feel ready to take on Cublas. 
You hand out tickets to the event for your friends to see.\n *\n *    ---------------------------------------------------------------------------------\n *    |           CuBLAS vs TriMul - Fight of the Century                             |\n *    |                                                                               |\n *    |   Ticket code:                                                                |\n *    |   > nvcc -O3 --use_fast_math trimat_forward.cu -o trimat_forward -lcublas     |\n *    |   > ./trimat_forward 4                                                        |\n *    |                                                                               |\n *    ---------------------------------------------------------------------------------\n */\n\nvoid trimul_gpu(int kernel_num,\n                float* out,  const float* inp,\n                int B, int T, int C, int NH) {\n    switch (kernel_num) {\n        case 0:\n            trimul_cublas(out, inp, B, T, C, NH);\n            break;\n        case 1:\n            trimul_launcher<matmul_tri_naive>(out, inp, B, T, C, NH);\n            break;\n        case 2:\n            trimul_launcher<matmul_tri_registers>(out, inp, B, T, C, NH);\n            break;\n        case 3:\n            trimul_launcher<matmul_tri3>(out, inp, B, T, C, NH);\n            break;\n        case 4:\n            trimul_launcher<matmul_tri4>(out, inp, B, T, C, NH);\n            break;\n        default:\n            printf(\"Invalid kernel number\\n\");\n            exit(1);\n    }\n}\n\n\n\nint main(int argc, char **argv) {\n    setup_main();\n\n    int B = 8;\n    int T = 1024;\n    int C = 768;\n    int NH = 12;\n\n    // create host memory of random numbers\n    float* out = (float*)malloc(B * NH * T * T * sizeof(float));\n    float* inp = make_random_float(B * T * 3 * C);\n\n    // move to GPU\n    float* d_out;\n    float* d_inp;\n    cudaCheck(cudaMalloc(&d_out, B * NH * T * T * sizeof(float)));\n    
cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));\n    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));\n\n    // buffer for cublas\n    cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));\n\n    // read kernel_num from command line\n    int kernel_num = 1;\n    if (argc > 1) {\n        kernel_num = atoi(argv[1]);\n    }\n    printf(\"Using kernel %d\\n\", kernel_num);\n\n    // first check the correctness of the kernel\n    trimul_cpu(out, inp, B, T, C, NH);\n    trimul_gpu(kernel_num, d_out, d_inp, B, T, C, NH);\n    validate_result(d_out, out, \"out\", B * NH * T * T, 1e-4f);\n\n    printf(\"All results match. Starting benchmarks.\\n\\n\");\n\n    // benchmark speed of the kernel\n    int repeat_times = 100;\n\n    float elapsed_time = benchmark_kernel(repeat_times, trimul_gpu,\n                                          kernel_num, d_out, d_inp,\n                                          B, T, C, NH);\n\n\n    float cublas_time = benchmark_kernel(repeat_times, trimul_gpu,\n                                         0, d_out, d_inp,\n                                         B, T, C, NH);\n\n    printf(\"time %.2f ms vs %.2f ms for CuBLAS\\n\", elapsed_time, cublas_time);\n\n    // free memory\n    free(out);\n    free(inp);\n    cudaCheck(cudaFree(d_out));\n    cudaCheck(cudaFree(d_inp));\n    cudaCheck(cudaFree(d_qkvr));\n    cublasDestroy(cublas_handle);\n\n    return 0;\n}\n"
  },
  {
    "path": "dev/data/README.md",
    "content": "# dev/data organization\n\nThe idea is that each dataset has a .py file here in the root of `dev/data`, and each dataset then creates a directory here, and writes and caches anything inside that directory. So for example:\n\n- running `python tinystories.py` will create a directory `tinystories` with its .bin files inside it\n- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it\n\nAnd so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project to these directories accordingly.\n\nNote: we support \"gpt-2\" and \"llama\" (llama 3 in particular) models, and the above scripts will tokenize for gpt-2 by default.\n"
  },
  {
    "path": "dev/data/data_common.py",
    "content": "\"\"\"\nCommon utilities for the datasets\n\"\"\"\n\nimport requests\nfrom tqdm import tqdm\nimport numpy as np\n\n\ndef download_file(url: str, fname: str, chunk_size=1024):\n    \"\"\"Helper function to download a file from a given url\"\"\"\n    resp = requests.get(url, stream=True)\n    total = int(resp.headers.get(\"content-length\", 0))\n    with open(fname, \"wb\") as file, tqdm(\n        desc=fname,\n        total=total,\n        unit=\"iB\",\n        unit_scale=True,\n        unit_divisor=1024,\n    ) as bar:\n        for data in resp.iter_content(chunk_size=chunk_size):\n            size = file.write(data)\n            bar.update(size)\n\n\nHEADERS_INFO = {\n    \"gpt-2\": {\n        \"magic\": 20240520,\n        \"version\": 1,\n        \"token_dtype\": np.uint16,\n    },\n    \"llama-3\": {\n        \"magic\": 20240801,\n        \"version\": 7,\n        \"token_dtype\": np.uint32,\n    },\n}\n\ndef write_datafile(filename, toks, model_desc=\"gpt-2\"):\n    \"\"\"\n    Saves token data as a .bin file, for reading in C.\n    - First comes a header with 256 int32s\n    - The tokens follow, each as uint16 (gpt-2) or uint32 (llama)\n    \"\"\"\n    assert len(toks) < 2**31, \"token count too large\" # ~2.1B tokens\n    assert model_desc in [\"gpt-2\", \"llama-3\"], f\"unknown model descriptor {model_desc}\"\n    info = HEADERS_INFO[model_desc]\n    # construct the header\n    header = np.zeros(256, dtype=np.int32) # header is always 256 int32 values\n    header[0] = info[\"magic\"]\n    header[1] = info[\"version\"]\n    header[2] = len(toks) # number of tokens after the 256*4 bytes of header\n    # construct the data (numpy array of tokens)\n    toks_np = np.array(toks, dtype=info[\"token_dtype\"])\n    # write to file\n    num_bytes = (256 * 4) + (len(toks) * toks_np.itemsize)\n    print(f\"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format\")\n    with open(filename, \"wb\") as f:\n        
f.write(header.tobytes())\n        f.write(toks_np.tobytes())\n\ndef write_evalfile(filename, datas):\n    \"\"\"\n    Saves eval data as a .bin file, for reading in C.\n    Used for multiple-choice style evals, e.g. HellaSwag and MMLU\n    - First comes a header with 256 int32s\n    - The examples follow, each example is a stream of uint16_t:\n        - <START_EXAMPLE> delimiter of 2**16-1, i.e. 65,535\n        - <EXAMPLE_BYTES>, bytes encoding this example, allowing efficient skip to next\n        - <EXAMPLE_INDEX>, the index of the example in the dataset\n        - <LABEL>, the index of the correct completion\n        - <NUM_COMPLETIONS>, indicating the number of completions (usually 4)\n        - <NUM><CONTEXT_TOKENS>, where <NUM> is the number of tokens in the context\n        - <NUM><COMPLETION_TOKENS>, repeated NUM_COMPLETIONS times\n    \"\"\"\n    # construct the header\n    header = np.zeros(256, dtype=np.int32)\n    header[0] = 20240522 # magic\n    header[1] = 1 # version\n    header[2] = len(datas) # number of examples\n    header[3] = 0 # reserved for longest_example_bytes, fill in later\n    # now write the individual examples\n    longest_example_bytes = 0 # in units of uint16s\n    full_stream = [] # the stream of uint16s, we'll write a single time at the end\n    assert len(datas) < 2**16, \"too many examples?\"\n    for idx, data in enumerate(datas):\n        stream = []\n        # header of the example\n        stream.append(2**16-1) # <START_EXAMPLE>\n        stream.append(0) # <EXAMPLE_BYTES> (fill in later)\n        stream.append(idx) # <EXAMPLE_INDEX>\n        stream.append(data[\"label\"]) # <LABEL>\n        ending_tokens = data[\"ending_tokens\"]\n        assert len(ending_tokens) == 4, \"expected 4 completions for now? 
can relax later\"\n        stream.append(len(ending_tokens)) # <NUM_COMPLETIONS>\n        # the (shared) context tokens\n        ctx_tokens = data[\"ctx_tokens\"]\n        assert all(0 <= t < 2**16-1 for t in ctx_tokens), \"bad context token\"\n        stream.append(len(ctx_tokens))\n        stream.extend(ctx_tokens)\n        # the completion tokens\n        for end_tokens in ending_tokens:\n            assert all(0 <= t < 2**16-1 for t in end_tokens), \"bad completion token\"\n            stream.append(len(end_tokens))\n            stream.extend(end_tokens)\n        # write to full stream\n        nbytes = len(stream)*2 # 2 bytes per uint16\n        assert nbytes < 2**16, \"example too large?\"\n        stream[1] = nbytes # fill in the <EXAMPLE_BYTES> field\n        longest_example_bytes = max(longest_example_bytes, nbytes)\n        full_stream.extend(stream)\n    # construct the numpy array\n    stream_np = np.array(full_stream, dtype=np.uint16)\n    # fill in the longest_example field\n    assert 0 < longest_example_bytes < 2**16, f\"bad longest_example\"\n    header[3] = longest_example_bytes\n    # write to file (for HellaSwag val this is 10,042 examples, 3.6MB file)\n    print(f\"writing {len(datas):,} examples to {filename}\")\n    with open(filename, \"wb\") as f:\n        f.write(header.tobytes())\n        f.write(stream_np.tobytes())\n"
  },
  {
    "path": "dev/data/edu_fineweb.sh",
    "content": "#!/bin/bash\n\n# Downloads the FineWeb-Edu 100B dataset, but in an already tokenized format in .bin files\n# Example: ./edu_fineweb.sh 100\n# would download 100 shards\n# Default is all shards\n# Make sure to run this from current directory, i.e. inside ./dev/data!\n\n# Check if MAX_SHARDS is provided as positional first arg, otherwise default to 1001\nif [ $# -eq 0 ]; then\n    MAX_SHARDS=1001\nelse\n    MAX_SHARDS=$1\nfi\n\nif [ $MAX_SHARDS -gt 1001 ]; then\n    MAX_SHARDS=1001\nfi\n\n# Base URLs\nTRAIN_BASE_URL=\"https://huggingface.co/datasets/karpathy/fineweb-edu-100B-gpt2-token-shards/resolve/main/edu_fineweb_train_\"\nVAL_URL=\"https://huggingface.co/datasets/karpathy/fineweb-edu-100B-gpt2-token-shards/resolve/main/edu_fineweb_val_000000.bin\"\n\n# Directory to save files\nSAVE_DIR=\"edu_fineweb100B\"\n\n# Create the directory if it doesn't exist\nmkdir -p \"$SAVE_DIR\"\n\ndownload() {\n    local FILE_URL=$1\n    local FILE_NAME=$(basename $FILE_URL | cut -d'?' -f1)\n    local FILE_PATH=\"${SAVE_DIR}/${FILE_NAME}\"\n    curl -s -L -o \"$FILE_PATH\" \"$FILE_URL\"\n    echo \"Downloaded $FILE_NAME to $SAVE_DIR\"\n}\n\n# Function to manage parallel jobs\nrun_in_parallel() {\n    local max_jobs=$1\n    shift\n    local commands=(\"$@\")\n    local job_count=0\n\n    for cmd in \"${commands[@]}\"; do\n        eval \"$cmd\" &\n        ((job_count++))\n        if (( job_count >= max_jobs )); then\n            wait -n\n            ((job_count--))\n        fi\n    done\n\n    # Wait for any remaining jobs to finish\n    wait\n}\n\n# Export the function so it's available in subshells\nexport -f download\n\n# Download the validation shard\ndownload \"$VAL_URL\" &\n\n# Generate train file shard download commands\ntrain_commands=()\nfor i in $(seq -f \"%06g\" 1 $MAX_SHARDS); do\n    FILE_URL=\"${TRAIN_BASE_URL}${i}.bin?download=true\"\n    train_commands+=(\"download \\\"$FILE_URL\\\"\")\ndone\n\n# Run the train file commands in parallel\nrun_in_parallel 
40 \"${train_commands[@]}\"\necho \"The val shard and first $MAX_SHARDS train shards of FineWebEdu100B files downloaded in $SAVE_DIR\"\n"
  },
  {
    "path": "dev/data/fineweb.py",
    "content": "\"\"\"\nFineWeb dataset (for srs pretraining)\nhttps://huggingface.co/datasets/HuggingFaceFW/fineweb\n\nexample doc to highlight the structure of the dataset:\n{\n  \"text\": \"Posted by mattsmith on 20th April 2012\\nStraight from...\",\n  \"id\": \"<urn:uuid:d853d453-196e-4488-a411-efc2b26c40d2>\",\n  \"dump\": \"CC-MAIN-2013-20\",\n  \"url\": \"http://nleastchatter.com/philliesphandom/tag/freddy-galvis/\",\n  \"date\": \"2013-05-18T07:24:47Z\",\n  \"file_path\": \"s3://commoncrawl/long.../path.../file.gz\",\n  \"language\": \"en\",\n  \"language_score\": 0.9185474514961243,\n  \"token_count\": 594\n}\n\nExample of downloading the 100B dataset of FineWebEDU, from root directory:\npython dev/data/fineweb.py -t edu -v 100B\nThe 100B version runs for a few hours, depending on your internet and computer.\n\"\"\"\nimport os\nimport argparse\nimport multiprocessing as mp\n\nimport numpy as np\nimport tiktoken\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\nfrom transformers import AutoTokenizer\n\n\nfrom data_common import write_datafile\n# ------------------------------------------\n\nparser = argparse.ArgumentParser(description=\"FineWeb and Edu-FineWeb dataset preprocessing\")\nparser.add_argument(\"-t\", \"--type\", type=str, default=\"classic\", help=\"Fineweb type, edu|classic\")\nparser.add_argument(\"-v\", \"--version\", type=str, default=\"10B\", help=\"Fineweb data sample size, 10B|100B\")\nparser.add_argument(\"-m\", \"--model_desc\", type=str, default=\"gpt-2\", help=\"Model descriptor, gpt-2|llama-3\")\nparser.add_argument(\"-s\", \"--shard_size\", type=int, default=10**8, help=\"Size of each data shard in the output .bin files, in tokens\")\nargs = parser.parse_args()\n\n# FineWeb has a few possible subsamples available\nassert args.version in {\"10B\", \"100B\"}, \"version must be one of: 10B, 100B\"\nassert args.type in {\"edu\", \"classic\"}, \"type must be one of: edu, classic\"\ndirectories = {\n    (\"classic\", \"10B\"): 
(\"fineweb10B\", \"sample-10BT\"),\n    (\"classic\", \"100B\"): (\"fineweb100B\", \"sample-100BT\"),\n    (\"edu\", \"10B\"): (\"edu_fineweb10B\", \"sample-10BT\"),\n    (\"edu\", \"100B\"): (\"edu_fineweb100B\", \"sample-100BT\")\n}\nlocal_dir, remote_name = directories[(args.type, args.version)]\n\n# create the local cache directory if it doesn't exist yet\nDATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)\nos.makedirs(DATA_CACHE_DIR, exist_ok=True)\n\n# download the dataset\nif args.type == \"classic\":\n    fw = load_dataset(\"HuggingFaceFW/fineweb\", name=remote_name, split=\"train\")\n    name = \"fineweb\"\nelif args.type ==\"edu\":\n    fw = load_dataset(\"HuggingFaceFW/fineweb-edu\", name=remote_name, split=\"train\")\n    name = \"edu_fineweb\"\n\ndef tokenize_llama(doc):\n    # tokenizes a single document and returns a numpy array of uint32 tokens\n    tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B\")\n    encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)\n    eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)\n    tokens = [eot] # the special <|endoftext|> token delimits all documents\n    tokens.extend(encode(doc[\"text\"]))\n    tokens_np = np.array(tokens)\n    assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), \"token dictionary too large for uint32\"\n    tokens_np_uint = tokens_np.astype(np.uint32)\n    return tokens_np_uint\n\ndef tokenize_gpt2(doc):\n    # tokenizes a single document and returns a numpy array of uint16 tokens\n    enc = tiktoken.get_encoding(\"gpt2\")\n    encode = lambda s: enc.encode_ordinary(s)\n    eot = enc._special_tokens['<|endoftext|>'] # end of text token\n    tokens = [eot] # the special <|endoftext|> token delimits all documents\n    tokens.extend(encode(doc[\"text\"]))\n    tokens_np = np.array(tokens)\n    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), 
\"token dictionary too large for uint16\"\n    tokens_np_uint = tokens_np.astype(np.uint16)\n    return tokens_np_uint\n\ntoken_dtype = {\n    \"gpt-2\": np.uint16,\n    \"llama-3\": np.uint32\n}[args.model_desc]\n\n# tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)\nnprocs = max(1, os.cpu_count() - 2) # don't hog the entire system\nwith mp.Pool(nprocs) as pool:\n    shard_index = 0\n    # preallocate buffer to hold current shard\n    all_tokens_np = np.empty((args.shard_size,), dtype=token_dtype)\n    token_count = 0\n    progress_bar = None\n\n    tokenize = lambda x: None\n    if args.model_desc == \"gpt-2\":\n        tokenize = tokenize_gpt2\n    elif args.model_desc == \"llama-3\":\n        tokenize = tokenize_llama\n    else:\n        raise ValueError(f\"unknown model {args.model_desc}\")\n\n    for tokens in pool.imap(tokenize, fw, chunksize=16):\n\n        # is there enough space in the current shard for the new tokens?\n        if token_count + len(tokens) < args.shard_size:\n            # simply append tokens to current shard\n            all_tokens_np[token_count:token_count+len(tokens)] = tokens\n            token_count += len(tokens)\n            # update progress bar\n            if progress_bar is None:\n                progress_bar = tqdm(total=args.shard_size, unit=\"tokens\", desc=f\"Shard {shard_index}\")\n            progress_bar.update(len(tokens))\n        else:\n            # write the current shard and start a new one\n            split = \"val\" if shard_index == 0 else \"train\"\n            filename = os.path.join(DATA_CACHE_DIR, f\"{name}_{split}_{shard_index:06d}.bin\")\n            # split the document into whatever fits in this shard; the remainder goes to next one\n            remainder = args.shard_size - token_count\n            progress_bar.update(remainder)\n            all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]\n            write_datafile(filename, 
all_tokens_np.tolist(), args.model_desc)\n            shard_index += 1\n            progress_bar = None\n            # populate the next shard with the leftovers of the current doc\n            all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]\n            token_count = len(tokens)-remainder\n\n    # write any remaining tokens as the last shard\n    if token_count != 0:\n        split = \"val\" if shard_index == 0 else \"train\"\n        filename = os.path.join(DATA_CACHE_DIR, f\"{name}_{split}_{shard_index:06d}.bin\")\n        write_datafile(filename, (all_tokens_np[:token_count]).tolist(), args.model_desc)\n"
  },
  {
    "path": "dev/data/fineweb.sh",
    "content": "#!/bin/bash\n\n# Downloads the FineWeb100B dataset, but in an already tokenized format in .bin files\n# Example: ./fineweb.sh 100\n# would download 100 shards\n# Default is all shards\n\n# Check if MAX_SHARDS is provided as positional first arg, otherwise default to 1028\nif [ $# -eq 0 ]; then\n    MAX_SHARDS=1028\nelse\n    MAX_SHARDS=$1\nfi\n\n# Ensure MAX_SHARDS is not greater than 1028\nif [ $MAX_SHARDS -gt 1028 ]; then\n    MAX_SHARDS=1028\nfi\n\n# Base URLs\nTRAIN_BASE_URL=\"https://huggingface.co/datasets/chrisdryden/FineWebTokenizedGPT2/resolve/main/fineweb_train_\"\nVAL_URL=\"https://huggingface.co/datasets/chrisdryden/FineWebTokenizedGPT2/resolve/main/fineweb_val_000000.bin?download=true\"\n\n# Directory to save files\nSAVE_DIR=\"fineweb100B\"\n\n# Create the directory if it doesn't exist\nmkdir -p \"$SAVE_DIR\"\n\n# Function to download a single file into SAVE_DIR\ndownload() {\n    local FILE_URL=$1\n    local FILE_NAME=$(basename $FILE_URL | cut -d'?' -f1)\n    local FILE_PATH=\"${SAVE_DIR}/${FILE_NAME}\"\n\n    # Download the file\n    curl -s -L -o \"$FILE_PATH\" \"$FILE_URL\"\n    echo \"Downloaded $FILE_NAME to $SAVE_DIR\"\n}\n\n# Function to manage parallel jobs\nrun_in_parallel() {\n    local max_jobs=$1\n    shift\n    local commands=(\"$@\")\n    local job_count=0\n\n    for cmd in \"${commands[@]}\"; do\n        eval \"$cmd\" &\n        ((job_count++))\n        if (( job_count >= max_jobs )); then\n            wait -n\n            ((job_count--))\n        fi\n    done\n\n    # Wait for any remaining jobs to finish\n    wait\n}\n\n# Export the function so it's available in subshells\nexport -f download\n\n# Download\ndownload \"$VAL_URL\" &\n\n# Generate train file commands\ntrain_commands=()\nfor i in $(seq -f \"%06g\" 1 $MAX_SHARDS); do\n    FILE_URL=\"${TRAIN_BASE_URL}${i}.bin?download=true\"\n    train_commands+=(\"download \\\"$FILE_URL\\\"\")\ndone\n\n# Run the train file commands in parallel\nrun_in_parallel 40 
\"${train_commands[@]}\"\n\necho \"The val shard and first $MAX_SHARDS train shards of FineWeb100B files downloaded in $SAVE_DIR\"\n"
  },
  {
    "path": "dev/data/hellaswag.py",
    "content": "\"\"\"\nDownloads and evaluates HellaSwag in Python.\nThis then acts as the reference file for llm.c\nAlso writes the data (tokens, labels) to .bin files for parallel evaluation in C.\nhttps://github.com/rowanz/hellaswag\n\nExample HellaSwag json item:\n\n{\"ind\": 24, \"activity_label\": \"Roof shingle removal\", \"ctx_a\": \"A man is sitting on a roof.\", \"ctx_b\": \"he\", \"ctx\": \"A man is sitting on a roof. he\", \"split\": \"val\", \"split_type\": \"indomain\", \"label\": 3, \"endings\": [\"is using wrap to wrap a pair of skis.\", \"is ripping level tiles off.\", \"is holding a rubik's cube.\", \"starts pulling up roofing on a roof.\"], \"source_id\": \"activitynet~v_-JhWjGDPHMY\"}\n\nind: dataset ID\nactivity_label: The ActivityNet or WikiHow label for this example\ncontext: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.\nendings: a list of 4 endings. 
The correct index is given by label (0,1,2, or 3)\nsplit: train, val, or test.\nsplit_type: indomain if the activity label is seen during training, else zeroshot\nsource_id: Which video or WikiHow article this example came from\n\ngpt2 (124M)\n- eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)\n- this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)\n\ngpt2-xl (1558M)\n- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)\n- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)\n\nThe validation set of HellaSwag has a total of 10,042 examples.\n\"\"\"\n\nimport os\nimport json\nimport requests\nimport tiktoken\nfrom tqdm import tqdm\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom transformers import GPT2LMHeadModel\nfrom data_common import download_file, write_evalfile\n\n# -----------------------------------------------------------------------------\nDATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), \"hellaswag\")\n\nhellaswags = {\n    \"train\": \"https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl\",\n    \"val\": \"https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl\",\n    \"test\": \"https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl\",\n}\n\nenc = tiktoken.get_encoding(\"gpt2\")\n\ndef download(split):\n    \"\"\"Downloads HellaSwag DATA_CACHE_DIR\"\"\"\n    os.makedirs(DATA_CACHE_DIR, exist_ok=True)\n    data_url = hellaswags[split]\n    data_filename = os.path.join(DATA_CACHE_DIR, f\"hellaswag_{split}.jsonl\")\n    if not os.path.exists(data_filename):\n        print(f\"Downloading {data_url} to {data_filename}...\")\n        download_file(data_url, data_filename)\n    else:\n        print(f\"{data_filename} already exists, skipping download...\")\n\ndef render_example(example):\n    \"\"\"\n    Given the example as a dictionary, render it as 
three torch tensors:\n    - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)\n    - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)\n    - label (the index of the correct completion, which we hope has the highest likelihood)\n    \"\"\"\n    ctx = example[\"ctx\"]\n    label = example[\"label\"]\n    endings = example[\"endings\"]\n\n    # data needed to reproduce this eval on the C side\n    data = {\n        \"label\": label,\n        \"ctx_tokens\": None,\n        \"ending_tokens\": [],\n    }\n\n    # gather up all the tokens\n    ctx_tokens = enc.encode(ctx)\n    data[\"ctx_tokens\"] = ctx_tokens\n    tok_rows = []\n    mask_rows = []\n    for end in endings:\n        end_tokens = enc.encode(\" \" + end) # note: prepending \" \" because GPT-2 tokenizer\n        tok_rows.append(ctx_tokens + end_tokens)\n        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))\n        data[\"ending_tokens\"].append(end_tokens)\n\n    # have to be careful during the collation because the number of tokens in each row can differ\n    max_len = max(len(row) for row in tok_rows)\n    tokens = torch.zeros((4, max_len), dtype=torch.long)\n    mask = torch.zeros((4, max_len), dtype=torch.long)\n    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):\n        tokens[i, :len(tok_row)] = torch.tensor(tok_row)\n        mask[i, :len(mask_row)] = torch.tensor(mask_row)\n\n    return data, tokens, mask, label\n\ndef iterate_examples(split):\n    # there are 10,042 examples in total in val\n    download(split)\n    with open(os.path.join(DATA_CACHE_DIR, f\"hellaswag_{split}.jsonl\"), \"r\") as f:\n        for line in f:\n            example = json.loads(line)\n            yield example\n\n@torch.no_grad()\ndef evaluate(model_type, device):\n\n    torch.set_float32_matmul_precision('high') # use tf32\n\n    model = GPT2LMHeadModel.from_pretrained(model_type)\n    model.to(device)\n    # 
model = torch.compile(model)\n\n    datas = []\n    num_correct_norm = 0\n    num_correct = 0\n    num_total = 0\n    for example in iterate_examples(\"val\"):\n        data, tokens, mask, label = render_example(example)\n        datas.append(data)\n        tokens = tokens.to(device)\n        mask = mask.to(device)\n\n        # get the logits\n        logits = model(tokens).logits\n        # evaluate the autoregressive loss at all positions\n        shift_logits = (logits[..., :-1, :]).contiguous()\n        shift_tokens = (tokens[..., 1:]).contiguous()\n        flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))\n        flat_shift_tokens = shift_tokens.view(-1)\n        shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')\n        shift_losses = shift_losses.view(tokens.size(0), -1)\n        # now get the average loss just for the completion region (where mask == 1), in each row\n        shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token\n        masked_shift_losses = shift_losses * shift_mask\n        # sum and divide by the number of 1s in the mask\n        sum_loss = masked_shift_losses.sum(dim=1)\n        avg_loss = sum_loss / shift_mask.sum(dim=1)\n        # now we have a loss for each of the 4 completions\n        # the one with the lowest loss should be the most likely\n        pred = sum_loss.argmin().item()\n        pred_norm = avg_loss.argmin().item()\n\n        # accumulate stats\n        num_total += 1\n        num_correct += int(pred == label)\n        num_correct_norm += int(pred_norm == label)\n        print(f\"{num_total} acc: {num_correct/num_total:.4f} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}\")\n\n        # debug: pretty print a few examples, and the losses in each case\n        if num_total < 10:\n            print(\"---\")\n            print(f\"Context:\\n {example['ctx']}\")\n            print(f\"Endings:\")\n      
      for i, end in enumerate(example[\"endings\"]):\n                print(f\"{i} (loss: {avg_loss[i].item():.4f}) {end}\")\n            print(f\"predicted: {pred_norm}, actual: {label}\")\n\n    # now write the data to a .bin file\n    filename = os.path.join(DATA_CACHE_DIR, f\"hellaswag_val.bin\")\n    write_evalfile(filename, datas)\n\nif __name__ == \"__main__\":\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-m\", \"--model_type\", type=str, default=\"gpt2\", help=\"the model type to use\")\n    parser.add_argument(\"-d\", \"--device\", type=str, default=\"cuda\", help=\"the device to use\")\n    args = parser.parse_args()\n    evaluate(args.model_type, args.device)\n"
  },
  {
    "path": "dev/data/mmlu.py",
    "content": "\"\"\"\nDownloads and evaluates MMLU in Python.\nThis then acts as the reference file for llm.c\nhttps://github.com/hendrycks/test\n\ngpt2 (124M)\n- this script: 14042 acc: 0.2557 acc_norm: 0.2721\n\ngpt2-xl (1558M)\n- this script: 14042 acc: 0.2927 acc_norm: 0.3035\n\"\"\"\n\nimport os\nimport requests\nimport tiktoken\nimport pandas as pd\nfrom tqdm import tqdm\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom transformers import GPT2LMHeadModel\nfrom data_common import download_file\n\n# -----------------------------------------------------------------------------\nDATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), \"mmlu\")\n\nenc = tiktoken.get_encoding(\"gpt2\")\ndata_url = \"https://people.eecs.berkeley.edu/~hendrycks/data.tar\"\n\ndef download():\n    \"\"\"Downloads MMLU to DATA_CACHE_DIR\"\"\"\n    os.makedirs(DATA_CACHE_DIR, exist_ok=True)\n    data_filename = os.path.join(DATA_CACHE_DIR, f\"data.tar\")\n    if not os.path.exists(data_filename):\n        print(f\"Downloading {data_url} to {data_filename}...\")\n        download_file(data_url, data_filename)\n        os.system(f\"tar -xf {data_filename} -C {DATA_CACHE_DIR}\") # untar\n        # creates a directory \"data\" inside it, with e.g. 
data/test/*csv\n    else:\n        print(f\"{data_filename} already exists, skipping download...\")\n\ndef iterate_examples():\n    # there are 14,042 examples in total in the test set\n\n    download()\n    test_dir = os.path.join(DATA_CACHE_DIR, \"data\", \"test\")\n    csv_files = [f for f in os.listdir(test_dir) if f.endswith(\".csv\")]\n    for csv_file in csv_files:\n        csv_path = os.path.join(test_dir, csv_file)\n        print(csv_path)\n        df = pd.read_csv(csv_path, header=None)\n        n = df.shape[0]\n        for idx in range(n):\n            example = {\n                \"question\": df.iloc[idx, 0],\n                \"endings\": [df.iloc[idx, 1], df.iloc[idx, 2], df.iloc[idx, 3], df.iloc[idx, 4]],\n                \"label\": df.iloc[idx, 5],\n            }\n            yield example\n\ndef render_example(example):\n    \"\"\"\n    Given the example as a dictionary, render it as three torch tensors:\n    - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)\n    - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)\n    - label (the index of the correct completion, which we hope has the highest likelihood)\n    \"\"\"\n    ctx = f\"Question: {example['question']}\\n\\nAnswer:\"\n    ctx_tokens = enc.encode(ctx)\n\n    tok_rows = []\n    mask_rows = []\n    for end in example[\"endings\"]:\n        end_tokens = enc.encode(\" \" + str(end)) # note: prepending \" \" because GPT-2 tokenizer\n        tok_rows.append(ctx_tokens + end_tokens)\n        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))\n\n    # have to be careful during the collation because the number of tokens in each row can differ\n    max_len = max(len(row) for row in tok_rows)\n    tokens = torch.zeros((4, max_len), dtype=torch.long)\n    mask = torch.zeros((4, max_len), dtype=torch.long)\n    for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):\n        tokens[i, :len(tok_row)] = 
torch.tensor(tok_row)\n        mask[i, :len(mask_row)] = torch.tensor(mask_row)\n\n    label = \"ABCD\".index(example[\"label\"])\n    return tokens, mask, label\n\n@torch.no_grad()\ndef evaluate(model_type, device):\n\n    torch.set_float32_matmul_precision('high') # use tf32\n\n    model = GPT2LMHeadModel.from_pretrained(model_type)\n    model.to(device)\n    # model = torch.compile(model)\n\n    num_correct_norm = 0\n    num_correct = 0\n    num_total = 0\n    for example in iterate_examples():\n        tokens, mask, label = render_example(example)\n        tokens = tokens.to(device)\n        mask = mask.to(device)\n\n        # get the logits\n        logits = model(tokens).logits\n        # evaluate the autoregressive loss at all positions\n        shift_logits = (logits[..., :-1, :]).contiguous()\n        shift_tokens = (tokens[..., 1:]).contiguous()\n        flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))\n        flat_shift_tokens = shift_tokens.view(-1)\n        shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')\n        shift_losses = shift_losses.view(tokens.size(0), -1)\n        # now get the average loss just for the completion region (where mask == 1), in each row\n        shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token\n        masked_shift_losses = shift_losses * shift_mask\n        # sum and divide by the number of 1s in the mask\n        sum_loss = masked_shift_losses.sum(dim=1)\n        avg_loss = sum_loss / shift_mask.sum(dim=1)\n        # now we have a loss for each of the 4 completions\n        # the one with the lowest loss should be the most likely\n        pred = sum_loss.argmin().item()\n        pred_norm = avg_loss.argmin().item()\n\n        # accumulate stats\n        num_total += 1\n        num_correct += int(pred == label)\n        num_correct_norm += int(pred_norm == label)\n        print(f\"{num_total} acc: 
{num_correct/num_total:.4f} acc_norm: {num_correct_norm/num_total:.4f}\")\n\n        # debug prints\n        if num_total < 10:\n            print(\"---\")\n            print(f\"Context:\\n {example['question']}\")\n            print(f\"Endings:\")\n            for i, end in enumerate(example[\"endings\"]):\n                print(f\"{i} (loss: {avg_loss[i].item():.4f}) {end}\")\n            print(f\"predicted: {pred}, actual: {label}\")\n\nif __name__ == \"__main__\":\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"-m\", \"--model_type\", type=str, default=\"gpt2\", help=\"the model type to use\")\n    parser.add_argument(\"-d\", \"--device\", type=str, default=\"cuda\", help=\"the device to use\")\n    args = parser.parse_args()\n    evaluate(args.model_type, args.device)\n"
  },
  {
    "path": "dev/data/tinyshakespeare.py",
    "content": "\"\"\"\nDownloads and tokenizes the TinyShakespeare dataset.\n- The download is from Github.\n- The tokenization is GPT-2 tokenizer with tiktoken\n\nThe output is written to a newly created tinyshakespeare/ folder.\nThe script prints:\n\nFor GPT-2:\n$ python dev/data/tinyshakespeare.py --model=gpt-2\nwriting 32,768 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_val.bin (66,560 bytes) in the gpt-2 format\nwriting 305,260 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_train.bin (611,544 bytes) in the gpt-2 format\n\nFor LLaMA 3:\n$ python dev/data/tinyshakespeare.py --model=llama-3\nwriting 32,768 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_val.bin (132,096 bytes) in the llama-3 format\nwriting 276,224 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_train.bin (1,105,920 bytes) in the llama-3 format\n\nAnd runs in a few seconds depending on your internet\nconnection and computer. 
The .bin files are raw byte\nstreams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.\n\"\"\"\n\nimport argparse\nimport os\n\nimport tiktoken\nfrom transformers import AutoTokenizer\n\nfrom data_common import download_file, write_datafile\n\n# -----------------------------------------------------------------------------\nDATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), \"tinyshakespeare\")\n\ndef download():\n    \"\"\"Downloads the TinyShakespeare dataset to DATA_CACHE_DIR\"\"\"\n    os.makedirs(DATA_CACHE_DIR, exist_ok=True)\n    # download the TinyShakespeare dataset, unless it's already downloaded\n    data_url = \"https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\"\n    data_filename = os.path.join(DATA_CACHE_DIR, \"tiny_shakespeare.txt\")\n    if not os.path.exists(data_filename):\n        print(f\"Downloading {data_url} to {data_filename}...\")\n        download_file(data_url, data_filename)\n    else:\n        print(f\"{data_filename} already exists, skipping download...\")\n\ndef tokenize(model_desc):\n    if model_desc == \"gpt-2\":\n        enc = tiktoken.get_encoding(\"gpt2\")\n        encode = lambda s: enc.encode_ordinary(s)\n        eot = enc._special_tokens['<|endoftext|>'] # end of text token\n    elif model_desc == \"llama-3\":\n        tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B\")\n        encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)\n        eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)\n    else:\n        raise ValueError(f\"unknown model descriptor {model_desc}\")\n    data_filename = os.path.join(DATA_CACHE_DIR, \"tiny_shakespeare.txt\")\n    text = open(data_filename, 'r').read()\n    # let's treat every individual chunk of text as a separate \"document\"\n    sections = text.split(\"\\n\\n\")\n    tokens = []\n    for i, s in 
enumerate(sections):\n        tokens.append(eot)\n        # there was a mild bug where I originally intended to remove \\n\\n, but instead just added\n        # the EOT right after each \\n\\n, so I'm keeping that behavior for backwards compatibility\n        # therefore we have to here add an extra \\n\\n at the end of each section, except the last\n        spad = s + \"\\n\\n\" if i != len(sections) - 1 else s\n        tokens.extend(encode(spad))\n    # let's take the first 32,768 tokens as the validation split (~10%)\n    val_tokens = tokens[:32768]\n    train_tokens = tokens[32768:]\n    # save to file\n    val_filename = os.path.join(DATA_CACHE_DIR, \"tiny_shakespeare_val.bin\")\n    train_filename = os.path.join(DATA_CACHE_DIR, \"tiny_shakespeare_train.bin\")\n    write_datafile(val_filename, val_tokens, model_desc)\n    write_datafile(train_filename, train_tokens, model_desc)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Tiny Shakespeare dataset preprocessing\")\n    parser.add_argument(\"-m\", \"--model_desc\", type=str, default=\"gpt-2\", choices=[\"gpt-2\", \"llama-3\"], help=\"Model type, gpt-2|llama-3\")\n    args = parser.parse_args()\n    download()\n    tokenize(args.model_desc)\n"
  },
  {
    "path": "dev/data/tinystories.py",
    "content": "\"\"\"\nDownloads and tokenizes the TinyStories dataset.\n- The download is from HuggingFace datasets.\n- The tokenization is using either GPT-2 or LLaMA 3 tokenizer.\n\nThe output is written to a newly created tinystories/ folder.\nThe script prints:\n\nFor GPT-2:\nNumber of shards: 50\nTokenizing val split...\nwriting 19,043,638 tokens to tinystories/TinyStories_val.bin\nTokenizing train split...\nwriting 925,653,391 tokens to tinystories/TinyStories_train.bin\n\nFor LLaMA 3:\nNumber of shards: 50\nTokenizing val split...\nwriting 18,660,516 tokens to tinystories/TinyStories_val.bin\nTokenizing train split...\nwriting 907,021,844 tokens to tinystories/TinyStories_train.bin\n\nAnd runs in few minutes two depending on your internet\nconnection and computer. The .bin files are raw byte\nstreams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.\n\"\"\"\n\nimport argparse\nimport os\nimport glob\nimport json\nimport random\nfrom concurrent.futures import ProcessPoolExecutor, as_completed\n\nimport tiktoken\nfrom transformers import AutoTokenizer\n\nfrom data_common import download_file, write_datafile\n\n# -----------------------------------------------------------------------------\nDATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), \"tinystories\")\n\ndef download():\n    \"\"\"Downloads the TinyStories dataset to DATA_CACHE_DIR\"\"\"\n    os.makedirs(DATA_CACHE_DIR, exist_ok=True)\n\n    # download the TinyStories dataset, unless it's already downloaded\n    data_url = \"https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz\"\n    data_filename = os.path.join(DATA_CACHE_DIR, \"TinyStories_all_data.tar.gz\")\n    if not os.path.exists(data_filename):\n        print(f\"Downloading {data_url} to {data_filename}...\")\n        download_file(data_url, data_filename)\n    else:\n        print(f\"{data_filename} already exists, skipping download...\")\n\n    # unpack the tar.gz file into 
all the data shards (json files)\n    data_dir = os.path.join(DATA_CACHE_DIR, \"TinyStories_all_data\")\n    if not os.path.exists(data_dir):\n        os.makedirs(data_dir, exist_ok=True)\n        print(f\"Unpacking {data_filename}...\")\n        os.system(f\"tar -xzf {data_filename} -C {data_dir}\")\n    else:\n        print(f\"{data_dir} already exists, skipping unpacking...\")\n\n    # print a single example just for debugging and such\n    shard_filenames = sorted(glob.glob(os.path.join(data_dir, \"*.json\")))\n    print(\"Download done.\")\n    print(f\"Number of shards: {len(shard_filenames)}\")\n    # with open(shard_filenames[0], \"r\") as f:\n    #     data = json.load(f)\n    # print(f\"Example story:\\n{data[0]}\")\n\ndef process_shard(shard_index, shard_filename, model_desc):\n    if model_desc == \"gpt-2\":\n        enc = tiktoken.get_encoding(\"gpt2\")\n        encode = lambda s: enc.encode_ordinary(s)\n        eot = enc._special_tokens['<|endoftext|>'] # end of text token\n    elif model_desc == \"llama-3\":\n        tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B\")\n        encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)\n        eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)\n    else:\n        raise ValueError(f\"unknown model descriptor {model_desc}\")\n\n    with open(shard_filename, \"r\") as f:\n        data = json.load(f)\n    rng = random.Random(1337 + shard_index)\n    rng.shuffle(data)\n    all_tokens = []\n    for example in data:\n        text = example[\"story\"]\n        text = text.strip()  # get rid of leading/trailing whitespace\n        tokens = encode(text)\n        all_tokens.append(eot)\n        all_tokens.extend(tokens)\n    return all_tokens\n\ndef tokenize(model_desc):\n    # shard 0 will be the val split, rest is train\n    data_dir = os.path.join(DATA_CACHE_DIR, \"TinyStories_all_data\")\n    
shard_filenames = sorted(glob.glob(os.path.join(data_dir, \"*.json\")))\n    val_shards = [shard_filenames[0]]\n    train_shards = shard_filenames[1:]\n    for split_name, split_shards in [(\"val\", val_shards), (\"train\", train_shards)]:\n\n        print(f\"Tokenizing {split_name} split...\")\n        all_tokens = []\n        with ProcessPoolExecutor() as executor:\n            futures = [executor.submit(process_shard, shard_index, shard_filename, model_desc)\n                       for shard_index, shard_filename in enumerate(split_shards)]\n            for future in as_completed(futures):\n                all_tokens.extend(future.result())\n\n        split_filename = os.path.join(DATA_CACHE_DIR, f\"TinyStories_{split_name}.bin\")\n        write_datafile(split_filename, all_tokens, model_desc)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Tiny Stories dataset preprocessing\")\n    parser.add_argument(\"-m\", \"--model_desc\", type=str, default=\"gpt-2\", choices=[\"gpt-2\", \"llama-3\"], help=\"Model type, gpt-2|llama-3\")\n    args = parser.parse_args()\n    download()\n    tokenize(args.model_desc)"
  },
  {
    "path": "dev/download_starter_pack.sh",
    "content": "#!/bin/bash\n\n# Get the directory of the script\nSCRIPT_DIR=$(dirname \"$(realpath \"$0\")\")\n\n# Base URL\nBASE_URL=\"https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/\"\n\n# Directory paths based on script location\nSAVE_DIR_PARENT=\"$SCRIPT_DIR/..\"\nSAVE_DIR_TINY=\"$SCRIPT_DIR/data/tinyshakespeare\"\nSAVE_DIR_HELLA=\"$SCRIPT_DIR/data/hellaswag\"\n\n# Create the directories if they don't exist\nmkdir -p \"$SAVE_DIR_TINY\"\nmkdir -p \"$SAVE_DIR_HELLA\"\n\n# Files to download\nFILES=(\n    \"gpt2_124M.bin\"\n    \"gpt2_124M_bf16.bin\"\n    \"gpt2_124M_debug_state.bin\"\n    \"gpt2_tokenizer.bin\"\n    \"tiny_shakespeare_train.bin\"\n    \"tiny_shakespeare_val.bin\"\n    \"hellaswag_val.bin\"\n)\n\n# Function to download files to the appropriate directory\ndownload_file() {\n    local FILE_NAME=$1\n    local FILE_URL=\"${BASE_URL}${FILE_NAME}?download=true\"\n    local FILE_PATH\n\n    # Determine the save directory based on the file name\n    if [[ \"$FILE_NAME\" == tiny_shakespeare* ]]; then\n        FILE_PATH=\"${SAVE_DIR_TINY}/${FILE_NAME}\"\n    elif [[ \"$FILE_NAME\" == hellaswag* ]]; then\n        FILE_PATH=\"${SAVE_DIR_HELLA}/${FILE_NAME}\"\n    else\n        FILE_PATH=\"${SAVE_DIR_PARENT}/${FILE_NAME}\"\n    fi\n\n    # Download the file\n    curl -s -L -o \"$FILE_PATH\" \"$FILE_URL\"\n    echo \"Downloaded $FILE_NAME to $FILE_PATH\"\n}\n\n# Export the function so it's available in subshells\nexport -f download_file\n\n# Generate download commands\ndownload_commands=()\nfor FILE in \"${FILES[@]}\"; do\n    download_commands+=(\"download_file \\\"$FILE\\\"\")\ndone\n\n# Function to manage parallel jobs in increments of a given size\nrun_in_parallel() {\n    local batch_size=$1\n    shift\n    local i=0\n    local command\n\n    for command; do\n        eval \"$command\" &\n        ((i = (i + 1) % batch_size))\n        if [ \"$i\" -eq 0 ]; then\n            wait\n        fi\n    done\n\n    # Wait for any remaining 
jobs to finish\n    wait\n}\n\n# Run the download commands in parallel in batches of 6\nrun_in_parallel 6 \"${download_commands[@]}\"\n\necho \"All files downloaded and saved in their respective directories\""
  },
  {
    "path": "dev/eval/README.md",
    "content": "# eleuther eval readme\n\nThe goal here is to run the Eleuther Eval harness exactly in the same way as that used in the [huggingface LLM Leaderboard](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard).\n\nThe starting point is a `.bin` file trained by llm.c. We now have to export it to a huggingface model and then evaluate it.\n\nTo export the model, use [export_hf.py](export_hf.py). See its documentation up top. Example usage, from this directory:\n\n```bash\ncd dev/eval\npython export_hf.py --input model.bin --output output_dir\n```\n\nWhere you point to your model .bin file, and huggingface files get written to output_dir. The script can optionally also upload to huggingface hub. One more post-processing that is advisable is to go into the `output_dir`, open up the `config.json` there and add one more entry into the json object:\n\n```\n\"_attn_implementation\": \"flash_attention_2\"\n```\n\nTo use FlashAttention 2. We had trouble evaluating in bfloat16 without using FlashAttention 2 (the scores are much lower, and this was never fully resolved). This is a temporary hack/workaround.\n\nNow that we have the model in huggingface format, we download the Eleuther Eval Harness repo and run it. Head over to the parent/root directory of the llm.c repo and:\n\n```bash\ngit clone https://github.com/EleutherAI/lm-evaluation-harness/\ncd lm-evaluation-harness\ngit checkout b281b0921b636bc36ad05c0b0b0763bd6dd43463\npip install -e .\n```\n\nAnd then run the run_eval.sh script:\n\n```bash\n./dev/eval/run_eval.sh output_dir result_dir\n```\n\nWhere output_dir can either be local output dir (above), or a huggingface repo name. This will write eval json objects to `./lm-evaluation-harness/results/results_dir`. It will print the results into console, e.g. 
for a 774M model we see:\n\n```\n----------------------------------------\narc_challenge_25shot.json      : 30.4608\ngsm8k_5shot.json               : 0.1516\nhellaswag_10shot.json          : 57.8072\nmmlu_5shot.json                : 25.8682\ntruthfulqa_0shot.json          : 35.7830\nwinogrande_5shot.json          : 59.3528\n----------------------------------------\nAverage Score                  : 34.9039\n```\n\nBut you can additionally get these results later by running `summarize_eval.py`:\n\n```bash\npython dev/eval/summarize_eval.py lm-evaluation-harness/results/results_dir\n```\n\nThe same information will be printed again.\n\nFor some reason, the evaluation is quite expensive and runs for somewhere around 1-3 hours, even though it should be a few minutes at most. This has not been satisfactorily resolved so far."
  },
  {
    "path": "dev/eval/export_hf.py",
    "content": "\"\"\"\nScript to convert GPT2 models from llm.c binary format to Hugging Face\n\nIt can optinally upload to your account on Hugging Face if you have the CLI:\n  pip install -U \"huggingface_hub[cli]\"\n  huggingface-cli login\n\nExport to a local HF model:\n  python export_hf.py --input input_file.bin --output output_dir\n\nExport to a local HF model and also push to your account on Hugging Face:\n  python export_hf.py --input input_file.bin --output output_dir --push true\n\"\"\"\n\nimport numpy as np\nimport torch\nimport argparse, sys\nfrom transformers import GPT2Config, GPT2Tokenizer, GPT2LMHeadModel\n\n# -----------------------------------------------------------------------------\n# Tensor functions for both bfloat16 (from int16) and normal float32\n# Both return float32 tensors\n\ndef tensor_bf16(data_int16, transpose=False):\n    if transpose:\n        data_int16 = data_int16.transpose(1,0)\n    return torch.tensor(data_int16).view(torch.bfloat16).to(torch.float32)\n\ndef tensor_fp32(data_float32, transpose=False):\n    if transpose:\n        data_float32 = data_float32.transpose(1,0)\n    return torch.tensor(data_float32).view(torch.float32)\n\n# -----------------------------------------------------------------------------\n# Main conversion function\n\ndef convert(filepath, output, push_to_hub=False, out_dtype=\"bfloat16\"):\n    print(f\"Converting model {filepath} to {output} in {out_dtype} format and pushing to Hugging Face: {push_to_hub}\")\n\n    f = open(filepath, 'rb')\n    # Read in our header, checking the magic number and version\n    # version 3 = fp32, padded vocab\n    # version 5 = bf16, padded vocab\n    model_header = np.frombuffer(f.read(256*4), dtype=np.int32)\n    if model_header[0] != 20240326:\n        print(\"ERROR: magic number mismatch in the data .bin file!\")\n        exit(1)\n    version = model_header[1]\n    if not version in [3, 5]:\n        print(\"Bad version in model file\")\n        exit(1)\n\n    # Load 
in our model parameters\n    maxT = model_header[2].item() # max sequence length\n    V = model_header[3].item() # vocab size\n    L =  model_header[4].item() # num layers\n    H = model_header[5].item() # num heads\n    C = model_header[6].item() # channels\n    Vp = model_header[7].item() # padded vocab size\n\n    print(f\"{version=}, {maxT=}, {V=}, {Vp=}, {L=}, {H=}, {C=}\")\n\n    # Define the shapes of our parameters\n    shapes = {\n        'wte': (Vp, C),\n        'wpe': (maxT, C),\n        'ln1w': (L, C),\n        'ln1b': (L, C),\n        'qkvw': (L, 3 * C, C),\n        'qkvb': (L, 3 * C),\n        'attprojw': (L, C, C),\n        'attprojb': (L, C),\n        'ln2w': (L, C),\n        'ln2b': (L, C),\n        'fcw': (L, 4 * C, C),\n        'fcb': (L, 4 * C),\n        'fcprojw': (L, C, 4 * C),\n        'fcprojb': (L, C),\n        'lnfw': (C,),\n        'lnfb': (C,),\n    }\n\n    # Load in our weights given our parameter shapes\n    dtype = np.float32 if version == 3 else np.int16\n    w = {}\n    for key, shape in shapes.items():\n        num_elements = np.prod(shape)\n        data = np.frombuffer(f.read(num_elements * np.dtype(dtype).itemsize), dtype=dtype)\n        w[key] = data.reshape(shape)\n        # The binary file saves the padded vocab - drop the padding back to GPT2 size\n        if shape[0] == Vp:\n            w[key] = w[key].reshape(shape)[:(V-Vp), :]\n    # Ensure the file is fully read and then close\n    assert f.read() == b''\n    f.close()\n\n    # Map to our model dict, the tensors at this stage are always fp32\n    mk_tensor = {\n        3 : tensor_fp32,\n        5 : tensor_bf16,\n    }[version]\n    model_dict = {}\n    model_dict['transformer.wte.weight'] = mk_tensor(w['wte'])\n    model_dict['transformer.wpe.weight'] = mk_tensor(w['wpe'])\n    model_dict['lm_head.weight'] = model_dict['transformer.wte.weight'] # Tie weights\n    for i in range(L):\n        model_dict[f'transformer.h.{i}.ln_1.weight'] = mk_tensor(w['ln1w'][i])\n        
model_dict[f'transformer.h.{i}.ln_1.bias'] = mk_tensor(w['ln1b'][i])\n        model_dict[f'transformer.h.{i}.attn.c_attn.weight'] = mk_tensor(w['qkvw'][i], True)\n        model_dict[f'transformer.h.{i}.attn.c_attn.bias'] = mk_tensor(w['qkvb'][i])\n        model_dict[f'transformer.h.{i}.attn.c_proj.weight'] = mk_tensor(w['attprojw'][i], True)\n        model_dict[f'transformer.h.{i}.attn.c_proj.bias'] = mk_tensor(w['attprojb'][i])\n        model_dict[f'transformer.h.{i}.ln_2.weight'] = mk_tensor(w['ln2w'][i])\n        model_dict[f'transformer.h.{i}.ln_2.bias'] = mk_tensor(w['ln2b'][i])\n        model_dict[f'transformer.h.{i}.mlp.c_fc.weight'] = mk_tensor(w['fcw'][i], True)\n        model_dict[f'transformer.h.{i}.mlp.c_fc.bias'] = mk_tensor(w['fcb'][i])\n        model_dict[f'transformer.h.{i}.mlp.c_proj.weight'] = mk_tensor(w['fcprojw'][i], True)\n        model_dict[f'transformer.h.{i}.mlp.c_proj.bias'] = mk_tensor(w['fcprojb'][i])\n    model_dict['transformer.ln_f.weight'] = mk_tensor(w['lnfw'])\n    model_dict['transformer.ln_f.bias'] = mk_tensor(w['lnfb'])\n\n    # Create a GPT-2 model instance, in the requested dtype\n    config = GPT2Config(vocab_size = V,\n                        n_positions = maxT,\n                        n_ctx = maxT,\n                        n_embd = C,\n                        n_layer = L,\n                        n_head = H)\n    model = GPT2LMHeadModel(config)\n    if out_dtype == \"bfloat16\":\n        model = model.to(torch.bfloat16)\n\n    # Set the model dict and save\n    model.load_state_dict(model_dict)\n    model.save_pretrained(output, max_shard_size=\"5GB\", safe_serialization=True)\n\n    # Copy over a standard gpt2 tokenizer\n    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n    tokenizer.save_pretrained(output)\n\n    if push_to_hub:\n        print(f\"Uploading {output} to Hugging Face\")\n        model.push_to_hub(output)\n        tokenizer.push_to_hub(output)\n\ndef spin(output):\n    print(\"Taking the exported 
model for a spin...\")\n    print('-'*80)\n    from transformers import AutoModelForCausalLM, AutoTokenizer\n    tokenizer = AutoTokenizer.from_pretrained(output)\n    model = AutoModelForCausalLM.from_pretrained(output, attn_implementation=\"flash_attention_2\", torch_dtype=torch.bfloat16, device_map='cuda')\n    model.eval()\n    tokens = tokenizer.encode(\"During photosynthesis in green plants\", return_tensors=\"pt\")\n    tokens = tokens.to('cuda')\n    output = model.generate(tokens, max_new_tokens=64, repetition_penalty=1.3)\n    samples = tokenizer.batch_decode(output)\n    for sample in samples:\n        print('-'*30)\n        print(sample)\n\n# -----------------------------------------------------------------------------\n\nif __name__== '__main__':\n    parser=argparse.ArgumentParser()\n    parser.add_argument(\"--input\", \"-i\", help=\"The name of the llm.c model.bin file\", type=str, required=True)\n    parser.add_argument(\"--output\",\"-o\",  help=\"The Hugging Face output model directory\", type=str, required=True)\n    parser.add_argument(\"--dtype\", \"-d\", help=\"Output as either float32 or bfloat16 (default)\", type=str, default=\"bfloat16\")\n    parser.add_argument(\"--push\", \"-p\", help=\"Push the model to your Hugging Face account\", type=bool, default=False)\n    parser.add_argument(\"--spin\", \"-s\", help=\"Take the model for a spin at the end?\", type=bool, default=True)\n    args = parser.parse_args()\n    convert(args.input, args.output, args.push, args.dtype)\n    if args.spin:\n        spin(args.output)\n"
  },
  {
    "path": "dev/eval/run_eval.sh",
    "content": "# https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard\n# (See About tab -> REPRODUCIBILITY)\n\n# This script is intended to be run from the parent/root directory of llm.c repo.\n\n# Clone the evaluation harness:\n\n# git clone https://github.com/EleutherAI/lm-evaluation-harness/\n# cd lm-evaluation-harness\n# git checkout b281b0921b636bc36ad05c0b0b0763bd6dd43463\n# pip install -e .\n\n# Then return to the parent directory and run this script\n\n# cd ..\n# ./dev/eval/run_eval.sh [model_name] [result_name]\n\n# where model_name is either a HF model such as openai-community/gpt2 or a local path such as ./gpt2-124M-run1\n# and result_name is the name of the folder under lm-evaluation-harness/results to store the evaluations\n\n# Since the evals can take a couple of hours to run, depending on the model size, you may wish to\n# run within a \"screen\" session or by using nohup to run the script:\n\n# nohup ./dev/eval/run_eval.sh [model_name] [result_name] > run.txt 2> err.txt &\n\nif [ -z \"$1\" ]; then\n    echo \"Error: missing HuggingFace model name or path to local model\"\n    echo \"./run_eval.sh hf_account/model_name my_result\"\n  exit 1\nfi\nif [ -z \"$2\" ]; then\n  echo \"Error: missing output name for results\"\n    echo \"./run_eval.sh hf_account/model_name my_result\"\n  exit 1\nfi\n\nexport MODEL=\"$(realpath -s \"$1\")\"\nexport RESULT=\"$2\"\necho \"Evaluating model $MODEL\"\necho \"Saving results to ./lm-evaluation-harness/results/$RESULT\"\n\ncd lm-evaluation-harness\n\npython main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks truthfulqa_mc --batch_size 1 --no_cache --write_out --output_path results/$RESULT/truthfulqa_0shot.json --device cuda\npython main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks winogrande --batch_size 1 --no_cache --write_out --output_path 
results/$RESULT/winogrande_5shot.json --device cuda --num_fewshot 5\npython main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks arc_challenge --batch_size 1 --no_cache --write_out --output_path results/$RESULT/arc_challenge_25shot.json --device cuda --num_fewshot 25\npython main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks hellaswag --batch_size 1 --no_cache --write_out --output_path results/$RESULT/hellaswag_10shot.json --device cuda --num_fewshot 10\npython main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks gsm8k --batch_size 1 --no_cache --write_out --output_path results/$RESULT/gsm8k_5shot.json --device cuda --num_fewshot 5\npython main.py --model hf-causal-experimental --model_args pretrained=$MODEL,use_accelerate=True,trust_remote_code=True --tasks hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics
,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions --batch_size 1 --no_cache --write_out --output_path results/$RESULT/mmlu_5shot.json --device cuda --num_fewshot 5\n\ncd ..\npython dev/eval/summarize_eval.py lm-evaluation-harness/results/$RESULT\n"
  },
  {
    "path": "dev/eval/summarize_eval.py",
    "content": "# example run command\n# python dev/eval/summarize_eval.py lm-evaluation-harness/results/result774M\n# this script is optional, the run_eval.sh should already print these\n# but this script can be used to re-print them\n\nimport json, sys\n\nRESULT = sys.argv[1]\nprint(\"-\"*40)\n\nkey = {\"arc_challenge_25shot.json\": \"acc_norm\",\n       \"gsm8k_5shot.json\": \"acc\",\n       \"hellaswag_10shot.json\": \"acc_norm\",\n       \"mmlu_5shot.json\": \"acc\",\n       \"truthfulqa_0shot.json\": \"mc2\",\n       \"winogrande_5shot.json\": \"acc\"\n       }\n\ntotal = 0\nfor test in [\"arc_challenge_25shot.json\", \"gsm8k_5shot.json\", \"hellaswag_10shot.json\", \"mmlu_5shot.json\", \"truthfulqa_0shot.json\", \"winogrande_5shot.json\"]:\n    data = json.loads(open(\"./%s/%s\"%(RESULT, test)).read())\n    r_count = 0\n    r_total = 0\n    for test_name in data['results']:\n      r_count += 1\n      r_total += data['results'][test_name][key[test]]\n    score = (r_total*100)/r_count\n    print(f\"{test:<30} : {score:.4f}\")\n    total += score\naverage = total / 6.0\nprint(\"-\"*40)\nprint(f\"Average Score                  : {average:.4f}\")\n"
  },
  {
    "path": "dev/loss_checker_ci.py",
    "content": "# Description: A script to compare numbers in a file with fixed values and check for accuracy within a specified percent difference.\n# Usage: python loss_checker_ci.py -f <file_path> -s <col_start> -e <col_end> -a <percent_accuracy>\n# Example: python dev/loss_checker_ci.py -f train_gpt2cu_fp32_precision.txt -s 20 -e 28 -a 10.0\nimport sys\nimport argparse\n\ndef read_numbers_from_file(file_path, col_start, col_end):\n    try:\n        numbers = []\n        with open(file_path, 'r') as file:\n            lines = file.readlines()\n            start_index = None\n            for i, line in enumerate(lines):\n                if \"step    1/10\" in line:\n                    start_index = i\n                    break\n\n            if start_index is None:\n                print(\"Error: Could not find the string 'step    1/10' in the file.\")\n                return None\n\n            # Read 10 rows starting from the identified start row\n            for line in lines[start_index:start_index + 10]:\n                # Extracting the specified columns\n                number = float(line[col_start:col_end].strip())\n                numbers.append(number)\n        return numbers\n    except Exception as e:\n        print(f\"Error reading the file: {e}\")\n        return None\n\ndef compare_numbers(read_values, fixed_values, percent_accuracy):\n    for i in range(len(read_values)):\n        read_value = read_values[i]\n        fixed_value = fixed_values[i]\n        percent_difference = ((read_value - fixed_value) / fixed_value) * 100\n        print(f\"Fixed Value: {fixed_value}, Read Value: {read_value}, Percent Difference: {percent_difference:.2f}%\")\n        if abs(percent_difference) > percent_accuracy:\n            print(f\"Error: Percent difference {percent_difference:.2f}% exceeds the allowed accuracy of {percent_accuracy}%\")\n            return 1\n    print(\"Success: All values are within the allowed accuracy.\")\n    return 0\n\ndef main():\n   
 parser = argparse.ArgumentParser(description='Compare numbers in a file with fixed values.')\n    parser.add_argument('-f', '--file', required=True, help='Path to the input file')\n    parser.add_argument('-s', '--col_start', type=int, required=True, help='Starting column index (0-based)')\n    parser.add_argument('-e', '--col_end', type=int, required=True, help='Ending column index (0-based)')\n    parser.add_argument('-a', '--percent_accuracy', type=float, required=True, help='Allowed percent accuracy for comparison')\n\n    args = parser.parse_args()\n\n    # Read numbers from file\n    read_values = read_numbers_from_file(args.file, args.col_start, args.col_end)\n    if read_values is None:\n        return 1\n\n    # Use values from test_gpt2.cu for fp32 precision\n    fixed_values = [5.270009,4.060681,3.320085,2.717550,2.181066,1.653923,1.168050,0.736873,0.401021,0.187493];\n\n    # Compare the numbers and check accuracy\n    result = compare_numbers(read_values, fixed_values, args.percent_accuracy)\n    return result\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "dev/test/Makefile",
    "content": "CC ?= gcc\n# example: make test_dataloader TEST_CFLAGS=-fsanitize=address -fno-omit-frame-pointer \nCFLAGS = -Ofast -Wno-unused-result -Wno-ignored-pragmas -Wno-unknown-attributes -g\nCFLAGS += $(TEST_CFLAGS)\nLDFLAGS =\nLDLIBS = -lm\nINCLUDES =\nCFLAGS_COND = -march=native\n\n# Find nvcc\nSHELL_UNAME = $(shell uname)\nREMOVE_FILES = rm -f\nOUTPUT_FILE = -o $@\nCUDA_OUTPUT_FILE = -o $@\n\n# NVCC flags\n# -t=0 is short for --threads, 0 = number of CPUs on the machine\nNVCC_FLAGS = -O3 -t=0 --use_fast_math -std=c++17\nNVCC_LDFLAGS = -lcublas -lcublasLt\nNVCC_INCLUDES =\nNVCC_LDLIBS =\nNVCC_CUDNN =\n# By default we don't build with cudnn because it blows up compile time from a few seconds to ~minute\nUSE_CUDNN ?= 0\n\n# We will place .o files in the `build` directory (create it if it doesn't exist)\nBUILD_DIR = build\n$(shell mkdir -p $(BUILD_DIR))\nREMOVE_BUILD_OBJECT_FILES := rm -f $(BUILD_DIR)/*.o\n\n# Function to check if a file exists in the PATH\ndefine file_exists_in_path\n  $(which $(1) 2>/dev/null)\nendef\n\nifneq ($(CI),true) # if not in CI, then use the GPU query\n  ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=\n    ifneq ($(call file_exists_in_path, __nvcc_device_query),)\n      GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query)\n      GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))\n    endif\n  endif\nendif\n\n# set to defaults if - make GPU_COMPUTE_CAPABILITY= otherwise use the compute capability detected above\nifneq ($(GPU_COMPUTE_CAPABILITY),)\n  NVCC_FLAGS += --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)]\nendif\n\n# autodect a lot of various supports on current platform\n$(info ---------------------------------------------)\n\nNVCC := $(shell which nvcc 2>/dev/null)\n\n# Check and include cudnn if available\n# You can override the path to cudnn frontend by setting CUDNN_FRONTEND_PATH on the make command 
line\n# By default, we look for it in HOME/cudnn-frontend/include and ./cudnn-frontend/include\n# Refer to the README for cuDNN install instructions\nifeq ($(USE_CUDNN), 1)\n  ifeq ($(shell [ -d $(HOME)/cudnn-frontend/include ] && echo \"exists\"), exists)\n    $(info ✓ cuDNN found, will run with flash-attention)\n    CUDNN_FRONTEND_PATH ?= $(HOME)/cudnn-frontend/include\n  else ifeq ($(shell [ -d cudnn-frontend/include ] && echo \"exists\"), exists)\n    $(info ✓ cuDNN found, will run with flash-attention)\n    CUDNN_FRONTEND_PATH ?= cudnn-frontend/include\n  else\n    $(error ✗ cuDNN not found. See the README for install instructions and the Makefile for hard-coded paths)\n  endif\n  NVCC_INCLUDES += -I$(CUDNN_FRONTEND_PATH)\n  NVCC_LDFLAGS += -lcudnn\n  NVCC_FLAGS += -DENABLE_CUDNN\n  NVCC_CUDNN = $(BUILD_DIR)/cudnn_att.o\nelse\n  $(info → cuDNN is manually disabled by default, run make with `USE_CUDNN=1` to try to enable)\nendif\n\n# Check if OpenMP is available\n# This is done by attempting to compile an empty file with OpenMP flags\n# OpenMP makes the code a lot faster so I advise installing it\n# e.g. on MacOS: brew install libomp\n# e.g. 
on Ubuntu: sudo apt-get install libomp-dev\n# later, run the program by prepending the number of threads, e.g.: OMP_NUM_THREADS=8 ./gpt2\n# First, check if NO_OMP is set to 1, if not, proceed with the OpenMP checks\nifeq ($(NO_OMP), 1)\n  $(info OpenMP is manually disabled)\nelse\n  ifneq ($(OS), Windows_NT)\n    # Check for OpenMP support in GCC or Clang on Linux\n    ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0)\n      CFLAGS += -fopenmp -DOMP\n      LDLIBS += -lgomp\n      $(info ✓ OpenMP found)\n    else\n      $(info ✗ OpenMP not found)\n    endif\n  endif\nendif\n\n# Check if OpenMPI and NCCL are available, include them if so, for multi-GPU training\nifeq ($(NO_MULTI_GPU), 1)\n  $(info → Multi-GPU (OpenMPI + NCCL) is manually disabled)\nelse\n  ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo \"exists\"), exists)\n    $(info ✓ OpenMPI found, OK to train with multiple GPUs)\n    NVCC_INCLUDES += -I/usr/lib/x86_64-linux-gnu/openmpi/include\n    NVCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/\n    NVCC_LDLIBS += -lmpi -lnccl\n    NVCC_FLAGS += -DMULTI_GPU\n  else\n    $(info ✗ OpenMPI is not found, disabling multi-GPU support)\n    $(info ---> On Linux you can try install OpenMPI with `sudo apt install openmpi-bin openmpi-doc libopenmpi-dev`)\n  endif\nendif\n\n# Precision settings, default to bf16 but ability to override\nifeq ($(MAKECMDGOALS), clean)\n  PRECISION=BF16 \nendif\n\nVALID_PRECISIONS := FP32 FP16 BF16\nifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),)\n  $(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS))\nendif\nifeq ($(PRECISION), FP32)\n  PFLAGS = -DENABLE_FP32\nelse ifeq ($(PRECISION), FP16)\n  PFLAGS = -DENABLE_FP16\nelse\n  PFLAGS = -DENABLE_BF16\nendif\n\n# PHONY means these targets will always be executed\n.PHONY: all clean\n\n# Add targets\nTARGETS = test_dataloader\n\n# Dependency 
files\ntest_dataloader_dependencies = test_dataloader.d\nHEADER_DEPENDENCIES = $(test_dataloader_dependencies)\n\n# Conditional inclusion of CUDA targets\nifeq ($(NVCC),)\n    $(info ✗ nvcc not found, skipping GPU/CUDA builds)\nelse\n    $(info ✓ nvcc found, including GPU/CUDA support)\n    TARGETS += \nendif\n\n$(info ---------Build Configuration Complete - Build Targets -------------------------)\n\nall: $(TARGETS)\n\n# Generate dependency files\n%.d: %.c\n\t$(CC) $(CFLAGS) -MMD -MP -MF $@ -c $<\n\n# Include the dependency files\n-include test_dataloader.d\n\ntest_dataloader: test_dataloader.c\n\t$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) -MMD -MP $^ $(LDLIBS) $(OUTPUT_FILE)\n\nclean:\n\t$(REMOVE_FILES) $(TARGETS) *.d *.o\n\t$(REMOVE_BUILD_OBJECT_FILES)\n"
  },
  {
    "path": "dev/test/device_file_io.cu",
    "content": "/*\nTests device <-> file IO functions\n\ncompile and run as (from dev/test directory)\nnvcc -o device_file_io device_file_io.cu && ./device_file_io\n*/\n\n\n#include \"../../llmc/cuda_common.h\"\n#include <vector>\n#include <random>\n#include <cstdio>\n#include <algorithm>\n\nvoid test(size_t nelem, size_t wt_buf_size, size_t rd_buf_size) {\n\n    float* data;\n    cudaCheck(cudaMalloc(&data, nelem*sizeof(float)));\n\n    // generate random array\n    std::vector<float> random_data(nelem);\n    std::mt19937 rng(42);\n    std::uniform_real_distribution<float> dist(-100.f, 100.f);\n    std::generate(random_data.begin(), random_data.end(), [&](){ return dist(rng); });\n\n    cudaCheck(cudaMemcpy(data, random_data.data(), random_data.size()*sizeof(float), cudaMemcpyHostToDevice));\n\n    cudaStream_t stream;\n    cudaStreamCreate(&stream);\n\n    FILE* tmp = fopenCheck(\"tmp.bin\", \"w\");\n    device_to_file(tmp, data, nelem * sizeof(float), wt_buf_size, stream);\n    fcloseCheck(tmp);\n\n\n    float* reload;\n    cudaCheck(cudaMalloc(&reload, nelem*sizeof(float)));\n\n    tmp  = fopenCheck(\"tmp.bin\", \"r\");\n    file_to_device(reload, tmp, nelem * sizeof(float), rd_buf_size, stream);\n    fcloseCheck(tmp);\n\n    std::vector<float> cmp(nelem);\n    cudaCheck(cudaMemcpy(cmp.data(), reload, nelem * sizeof(float), cudaMemcpyDeviceToHost));\n    for(int i = 0; i < nelem; ++i) {\n        if(random_data[i] != cmp[i])  {\n            fprintf(stderr, \"FAIL: Mismatch at position %d: %f vs %f\\n\", i, random_data[i], cmp[i]);\n            remove(\"tmp.bin\");\n            exit(EXIT_FAILURE);\n        }\n    }\n\n    cudaCheck(cudaFree(reload));\n    cudaCheck(cudaFree(data));\n    remove(\"tmp.bin\");\n}\n\nint main() {\n    test(1025, 10000, 10000);           // buffers larger than data\n    test(1025, 1024, 513);              // different and smaller\n    test(500, 500*sizeof(float),\n         500*sizeof(float));            // exact match\n    
test(125'000, 10000, 10000);        // large array\n}"
  },
  {
    "path": "dev/test/test_dataloader.c",
    "content": "/*\nTests our DataLoader\n\ncompile and run as (from dev/test directory)\ngcc -O3 -I../../llmc -o test_dataloader test_dataloader.c -lm && ./test_dataloader\n\nTODOs:\n- test load/save state of DataLoader\n*/\n#include <unistd.h>\n#include \"../../llmc/dataloader.h\"\n\n#define SHARD_NAME_LEN 64\nchar shard_name[SHARD_NAME_LEN];\nconst int num_tokens = 140;\nint num_shards = 4;\n\nvoid check_range(const int *tokens, const int start, const int end, const char *file, int line) {\n    // checks that the tokens[0, ... end-start] are the range [start, end)\n    int n = end - start;\n    for (int i = 0; i < n; i++) {\n        int token = tokens[i];\n        if (token != start + i) {\n            fprintf(stderr, \"Error: tokens[%d] = %d, expected %d\\n\", i, token, start + i);\n            fprintf(stderr, \"Error details:\\n\");\n            fprintf(stderr, \"  File: %s\\n\", file);\n            fprintf(stderr, \"  Line: %d\\n\", line);\n            exit(EXIT_FAILURE);\n        }\n    }\n    // printf(\"tokens in range [%d, %d) OK\\n\", start, end);\n}\n#define checkRange(tokens, start, end) check_range(tokens, start, end, __FILE__, __LINE__)\n\nvoid check_equals(const int *tokens, const int n, const int expected, const char *file, int line) {\n    // checks that the tokens[0, ... 
n] are all equal to expected\n    for (int i = 0; i < n; i++) {\n        int token = tokens[i];\n        if (token != expected) {\n            fprintf(stderr, \"Error: tokens[%d] = %d, expected %d\\n\", i, token, expected);\n            fprintf(stderr, \"Error details:\\n\");\n            fprintf(stderr, \"  File: %s\\n\", file);\n            fprintf(stderr, \"  Line: %d\\n\", line);\n            exit(EXIT_FAILURE);\n        }\n    }\n    // printf(\"tokens all equal to %d OK\\n\", expected);\n}\n#define checkEquals(tokens, n, expected) check_equals(tokens, n, expected, __FILE__, __LINE__)\n\nvoid test_simple(void) {\n    /*\n    Tests the simplest DataLoader functionality:\n    - multi-shard\n    - single-process\n    - not shuffled\n    DataLoader should just return all the tokens in order\n    */\n    printf(\"test_simple... \");\n    int B = 4;\n    int T = 8;\n    int process_rank = 0;\n    int num_processes = 1;\n    int should_shuffle = 0;\n    snprintf(shard_name, SHARD_NAME_LEN, \"shard_????.bin\");\n    DataLoader loader;\n    dataloader_init(&loader, shard_name, B, T, process_rank, num_processes, should_shuffle);\n\n    int batches_fit = num_tokens / (B * T); // number of batches that fit per shard\n    int BT = B * T;\n    int num_epochs = 4;\n    for (int e = 0; e < num_epochs; e++) { // epoch\n        for (int s = 0; s < num_shards; s++) { // shard\n            int start = s * num_tokens;\n            for (int b = 0; b < batches_fit; b++) { // batch\n                dataloader_next_batch(&loader);\n                checkRange(loader.inputs, start, start + BT);\n                checkRange(loader.targets, start + 1, start + BT + 1);\n                start += BT;\n            }\n        }\n    }\n    dataloader_free(&loader);\n    printf(\"OK\\n\");\n}\n\nvoid test_multiprocess_simple(void) {\n    /*\n    Same as simple above, but using 2 processes.\n    (which we of course use in a serial, single process way here)\n    The DataLoaders simply pull chunks 
of consecutive tokens, so\n    we expect them to alternate in the \"token space\".\n    */\n    printf(\"test_multiprocess_simple... \");\n    int B = 4;\n    int T = 8;\n    int num_processes = 2;\n    int should_shuffle = 0;\n    snprintf(shard_name, SHARD_NAME_LEN, \"shard_????.bin\");\n    DataLoader loader0, loader1;\n    dataloader_init(&loader0, shard_name, B, T, 0, num_processes, should_shuffle);\n    dataloader_init(&loader1, shard_name, B, T, 1, num_processes, should_shuffle);\n\n    int batches_fit = num_tokens / (B * T * num_processes); // number of batches that fit per shard\n    int BT = B * T;\n    int num_epochs = 4;\n    for (int e = 0; e < num_epochs; e++) { // epoch\n        for (int s = 0; s < num_shards; s++) { // shard\n            int start = s * num_tokens;\n            for (int b = 0; b < batches_fit; b++) { // batch\n                dataloader_next_batch(&loader0);\n                dataloader_next_batch(&loader1);\n                checkRange(loader0.inputs, start, start + BT);\n                checkRange(loader1.inputs, start + BT, start + 2*BT);\n                checkRange(loader0.targets, start + 1, start + BT + 1);\n                checkRange(loader1.targets, start + BT + 1, start + 2*BT + 1);\n                start += 2*BT;\n            }\n        }\n    }\n\n    dataloader_free(&loader0);\n    dataloader_free(&loader1);\n    printf(\"OK\\n\");\n}\n\nvoid test_shuffled(void) {\n    /*\n    Tests the DataLoader when using shuffled:\n    - multi-shard\n    - single-process\n    - shuffled!\n    DataLoader should return all the tokens, but in randperm order.\n    So all we check is that we see all the tokens we expect to see,\n    the correct number of times.\n    */\n    printf(\"test_shuffled... 
\");\n    int B = 4;\n    int T = 8;\n    int process_rank = 0;\n    int num_processes = 1;\n    int should_shuffle = 1; // should shuffle bit turn on\n    snprintf(shard_name, 64, \"shard_????.bin\");\n    DataLoader loader;\n    dataloader_init(&loader, shard_name, B, T, process_rank, num_processes, should_shuffle);\n\n    // get batches from the dataloader and keep stats on what tokens we see\n    int total_tokens = num_shards * num_tokens;\n    int *num_seen_inputs = (int *)calloc(total_tokens, sizeof(int));\n    int *num_seen_targets = (int *)calloc(total_tokens, sizeof(int));\n    int batches_fit = num_tokens / (B * T); // number of batches that fit per shard\n    int BT = B * T;\n    int num_epochs = 4;\n    for (int e = 0; e < num_epochs; e ++) { // epoch\n        for (int s = 0; s < num_shards; s++) { // shard\n            int start = s * num_tokens;\n            for (int b = 0; b < batches_fit; b++) { // batch\n                dataloader_next_batch(&loader);\n                // count up the tokens we see\n                for (int i = 0; i < BT; i++) {\n                    int input_token = loader.inputs[i];\n                    int target_token = loader.targets[i];\n                    assert(input_token >= 0 && input_token < total_tokens);\n                    assert(target_token >= 0 && target_token < total_tokens);\n                    num_seen_inputs[input_token]++;\n                    num_seen_targets[target_token]++;\n                }\n                start += BT;\n            }\n        }\n    }\n\n    // verify that we saw all the tokens the correct number of times\n    int tokens_fit = batches_fit * BT; // number of tokens that fit per shard\n    for (int s = 0; s < num_shards; s++) {\n        int start = s * num_tokens;\n        // verify the inputs counts for this shard:\n        // - the first tokens_fit should have been seen num_epochs times\n        // - the rest of the tokens in that should should have been seen zero times\n        
checkEquals(num_seen_inputs + start, tokens_fit, num_epochs);\n        checkEquals(num_seen_inputs + start + tokens_fit, num_tokens - tokens_fit, 0);\n        // verify the target counts. same thing but offset by 1\n        checkEquals(num_seen_targets + start + 1, tokens_fit, num_epochs);\n        checkEquals(num_seen_targets + start + 1 + tokens_fit,\n            (s == (num_shards - 1)) ? num_tokens - tokens_fit - 1 : num_tokens - tokens_fit,0);\n    }\n\n    dataloader_free(&loader);\n    free(num_seen_inputs);\n    free(num_seen_targets);\n    printf(\"OK\\n\");\n}\n\nvoid test_multiprocess_shuffled(void) {\n    /*\n    Tests the DataLoader when using both multiprocess and shuffled:\n    - multi-shard\n    - multi-process\n    - shuffled!\n    DataLoaders should return all the tokens, but in randperm order.\n    So all we check is that we see all the tokens we expect to see,\n    the correct number of times, over multiple epochs.\n    */\n\n    printf(\"test_multiprocess_shuffled... 
\");\n    int B = 4;\n    int T = 8;\n    const int num_processes = 2;\n    int should_shuffle = 0;\n    snprintf(shard_name, SHARD_NAME_LEN, \"shard_????.bin\");\n    DataLoader loaders[num_processes];\n    for (int i = 0; i < num_processes; i++) {\n        dataloader_init(&loaders[i], shard_name, B, T, i, num_processes, should_shuffle);\n    }\n\n    // get batches from the dataloader and keep stats on what tokens we see\n    int total_tokens = num_shards * num_tokens;\n    int *num_seen_inputs = (int *)calloc(total_tokens, sizeof(int));\n    int *num_seen_targets = (int *)calloc(total_tokens, sizeof(int));\n    int batches_fit = num_tokens / (B * T * num_processes); // number of batches that fit per shard\n    int BT = B * T;\n    int num_epochs = 4;\n    for (int e = 0; e < num_epochs; e ++) { // epoch\n        for (int s = 0; s < num_shards; s++) { // shard\n            int start = s * num_tokens;\n            for (int b = 0; b < batches_fit; b++) { // batch\n                for (int n = 0; n < num_processes; n++) { // dataloader\n                    DataLoader *loader = &loaders[n];\n                    dataloader_next_batch(loader);\n                    // count up the tokens we see\n                    for (int i = 0; i < BT; i++) {\n                        int input_token = loader->inputs[i];\n                        int target_token = loader->targets[i];\n                        assert(input_token >= 0 && input_token < total_tokens);\n                        assert(target_token >= 0 && target_token < total_tokens);\n                        num_seen_inputs[input_token]++;\n                        num_seen_targets[target_token]++;\n                    }\n                    start += BT;\n                }\n            }\n        }\n    }\n\n    // verify that we saw all the tokens the correct number of times\n    int tokens_fit = batches_fit * (B * T * num_processes); // number of tokens that fit per shard\n    for (int s = 0; s < num_shards; s++) {\n       
 int start = s * num_tokens; // token id that starts this shard\n        // verify the input counts for this shard:\n        // - the first tokens_fit should have been seen num_epochs times\n        // - the rest of the tokens in that shard should have been seen zero times\n        checkEquals(num_seen_inputs + start, tokens_fit, num_epochs);\n        checkEquals(num_seen_inputs + start + tokens_fit, num_tokens - tokens_fit, 0);\n        // verify the target counts. same thing but offset by 1\n        checkEquals(num_seen_targets + start + 1, tokens_fit, num_epochs);\n        checkEquals(num_seen_targets + start + 1 + tokens_fit,\n            (s == (num_shards - 1)) ? num_tokens - tokens_fit - 1 : num_tokens - tokens_fit,0);\n    }\n\n    // cleanup\n    for (int i = 0; i < num_processes; i++) {\n        dataloader_free(&loaders[i]);\n    }\n    free(num_seen_inputs);\n    free(num_seen_targets);\n    printf(\"OK\\n\");\n}\n\nint main(void) {\n\n    // generate a few dummy shards of data with incrementing tokens\n    int header[HEADER_SIZE];\n    uint16_t tokens[num_tokens];\n    for (int shard_id = 0; shard_id < num_shards; shard_id++) {\n        // ensure unique tokens across the shards for ez accounting below\n        int token_offset = shard_id * num_tokens;\n        for (int i = 0; i < num_tokens; i++) {\n            tokens[i] = token_offset + i;\n        }\n        // write the shard\n        snprintf(shard_name, SHARD_NAME_LEN, \"shard_%04d.bin\", shard_id);\n        header[0] = 20240520; // magic\n        header[1] = 1; // version\n        header[2] = num_tokens; // number of tokens within\n        FILE* shard_file = fopenCheck(shard_name, \"wb\");\n        fwrite(header, sizeof(int), HEADER_SIZE, shard_file);\n        fwrite(tokens, sizeof(uint16_t), num_tokens, shard_file);\n        fcloseCheck(shard_file);\n        printf(\"Wrote shard %s\\n\", shard_name);\n    }\n\n    test_simple();\n    test_multiprocess_simple();\n    test_shuffled();\n    
test_multiprocess_shuffled();\n\n    // clean up the shards\n    for (int shard_id = 0; shard_id < num_shards; shard_id++) {\n        snprintf(shard_name, SHARD_NAME_LEN, \"shard_%04d.bin\", shard_id);\n        remove(shard_name);\n    }\n\n    return EXIT_SUCCESS;\n}"
  },
  {
    "path": "dev/test/test_outlier_detector.c",
    "content": "/*\nTests our OutlierDetector\n\ncompile and run as (from dev/test directory)\ngcc -O3 -I../../llmc -o test_outlier_detector test_outlier_detector.c -lm && ./test_outlier_detector\n*/\n\n#include <stdlib.h>\n#include \"../../llmc/outlier_detector.h\"\n\nint main(void) {\n    OutlierDetector detector;\n    init_detector(&detector);\n\n    srand(1337); // init rng\n\n    // generate OUTLIER_DETECTOR_WINDOW_SIZE * 2 random numbers between -1 and 1\n    for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE * 2; i++) {\n        double val = (double)rand() / RAND_MAX * 2 - 1;  // Random number between -1 and 1\n        double zscore = update_detector(&detector, val);\n\n        printf(\"Step %d: Value = %.4f, zscore = %.4f\\n\", i, val, zscore);\n\n        // check that the first OUTLIER_DETECTOR_WINDOW_SIZE values return nan\n        if (i < OUTLIER_DETECTOR_WINDOW_SIZE) {\n            if (!isnan(zscore)) {\n                printf(\"Error: Expected nan, got %.4f\\n\", zscore);\n                return EXIT_FAILURE;\n            }\n        } else {\n            // check that the zscore is within reasonable bounds\n            if (zscore < -3.0 || zscore > 3.0) {\n                printf(\"Error: Z-score %.4f is outside of expected range\\n\", zscore);\n                return EXIT_FAILURE;\n            }\n        }\n    }\n\n    // simulate an outlier\n    double outlier = 10.0; // <--- loss spike\n    double zscore = update_detector(&detector, outlier);\n    printf(\"Outlier Step: Value = %.4f, zscore = %.4f\\n\", outlier, zscore);\n\n    // check that the z-score here is large\n    if (zscore < 5.0) {\n        printf(\"Error: Z-score %.4f is not large enough for an outlier\\n\", zscore);\n        return EXIT_FAILURE;\n    }\n\n    printf(\"OK\\n\");\n    return EXIT_SUCCESS;\n}\n"
  },
  {
    "path": "dev/unistd.h",
    "content": "// header file that is necessary to compile on Windows\n#ifndef UNISTD_H\n#define UNISTD_H\n\n#define _CRT_SECURE_NO_WARNINGS\n#define _USE_MATH_DEFINES\n#define WIN32_LEAN_AND_MEAN\n\n#include <stdio.h>\n#include <math.h>\n#include <time.h>\n#include <stdlib.h> // for malloc and free\n#include <string.h>\n#include <direct.h> // for _mkdir and _stat\n#include <io.h> // needed for _access below and _findfirst, _findnext, _findclose\n#pragma comment(lib, \"Ws2_32.lib\")  // Link Ws2_32.lib for socket functions\n#include <winsock2.h>\n\n#define CLOCK_MONOTONIC 0\nstatic inline int clock_gettime(int ignore_variable, struct timespec* tv)\n{\n    return timespec_get(tv, TIME_UTC); // TODO: not sure this is the best solution. Need to review.\n}\n\n#define OMP /* turn it on */\n#define F_OK 0\n#define access _access\n\n#define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise\n#define TURN_ON_FP_FAST  __pragma(float_control(pop)) // Restore file's default settings\n\n#define mkdir(path, mode) _mkdir(path) /* sketchy way to get mkdir to work on windows */\n#define stat _stat\n\ntypedef struct glob_t {\n    size_t gl_pathc;    // Count of matched pathnames\n    char **gl_pathv;    // List of matched pathnames\n} glob_t;\n\nstatic inline void replace_forward_slashes(char* str) {\n    while (*str) {\n        if (*str == '/') {\n            *str = '\\\\';\n        }\n        str++;\n    }\n}\n\nstatic inline void globfree(glob_t *pglob) {\n    for (size_t i = 0; i < pglob->gl_pathc; ++i) {\n        free(pglob->gl_pathv[i]); // Free the allocated memory for each filename\n    }\n    free(pglob->gl_pathv); // Free the allocated memory for the list of filenames\n}\n\nstatic inline int glob(const char* pattern, int ignored_flags, int (*ignored_errfunc)(const char* epath, int eerrno), glob_t* pglob){\n    struct _finddata_t find_file_data;\n    char full_path[576]; // stored in pglob->gl_pathv[n]\n    char 
directory_path[512] = {0}; // Store the directory path from the pattern\n    char pattern_copy[512]; // Copy of the pattern to modify\n\n    strncpy_s(pattern_copy, sizeof(pattern_copy) - 1, pattern, sizeof(pattern_copy) - 1);\n\n    replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes\n\n    if (strchr(pattern_copy, '\\\\') != (void*) NULL) {\n        strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\\\') - pattern_copy + 1);\n        directory_path[strrchr(pattern_copy, '\\\\') - pattern_copy + 1] = '\\0';\n    }\n\n    // find the first file matching the pattern in the directory\n    intptr_t find_handle = _findfirst(pattern_copy, &find_file_data);\n\n    if (find_handle == -1) {\n        return 1; // No files found\n    }\n\n    size_t file_count = 0;\n    size_t max_files = 64000; // hard-coded limit for the number of files\n\n    pglob->gl_pathv = (char **) malloc(max_files * sizeof(char*)); // freed in globfree\n\n    if (pglob->gl_pathv == NULL) {\n        _findclose(find_handle);\n        return 2; // Memory allocation failed\n    }\n\n    do {\n        if (file_count >= max_files) {\n            _findclose(find_handle);\n            return 2; // Too many files found\n            }\n\n        snprintf(full_path, sizeof(full_path), \"%s%s\", directory_path, find_file_data.name);\n\n        pglob->gl_pathv[file_count] = _strdup(full_path); // freed in globfree\n\n        if (pglob->gl_pathv[file_count] == NULL) {\n            _findclose(find_handle);\n            return 2; // Memory allocation for filename failed\n        }\n        file_count++;\n    } while (_findnext(find_handle, &find_file_data) == 0);\n\n    _findclose(find_handle);\n\n    pglob->gl_pathc = file_count;\n    return 0;\n}\n\n// dirent.h support\n\n#define MAX_PATH_LENGTH 512\ntypedef struct dirent {\n    char d_name[MAX_PATH_LENGTH];\n} dirent;\n\ntypedef struct DIR {\n    intptr_t handle;\n    struct _finddata_t 
findFileData;\n    int firstRead;\n} DIR;\n\nstatic inline DIR *opendir(const char *name) {\n    DIR *dir = (DIR *)malloc(sizeof(DIR));\n    if (dir == NULL) {\n        return NULL;\n    }\n\n    char searchPath[MAX_PATH_LENGTH];\n\n    snprintf(searchPath, MAX_PATH_LENGTH, \"%s\\\\*.*\", name);\n\n    dir->handle = _findfirst(searchPath, &dir->findFileData);\n    if (dir->handle == -1) {\n        free(dir);\n        return NULL;\n    }\n\n    dir->firstRead = 1;\n    return dir;\n}\n\nstatic inline struct dirent *readdir(DIR *directory) {\n    static struct dirent result;\n\n    if (directory->firstRead) {\n        directory->firstRead = 0;\n    } else {\n        if (_findnext(directory->handle, &directory->findFileData) != 0) {\n            return NULL;\n        }\n    }\n\n    strncpy(result.d_name, directory->findFileData.name, MAX_PATH_LENGTH);\n    result.d_name[MAX_PATH_LENGTH - 1] = '\\0'; // Ensure null termination\n    return &result;\n}\n\nstatic inline int closedir(DIR *directory) {\n    if (directory == NULL) {\n        return -1;\n    }\n\n    if (_findclose(directory->handle) != 0) {\n        return -1;\n    }\n\n    free(directory);\n    return 0;\n}\n#endif // UNISTD_H\n"
  },
  {
    "path": "dev/vislog.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Simple visualizer for log files written by the training loop\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"%matplotlib inline\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def parse_logfile(logfile):\\n\",\n    \"    # so the tricky part we have to deal with in these log files\\n\",\n    \"    # is that the job could crash and get restarted, which will\\n\",\n    \"    # re-wind back and start re-logging older steps. So we keep\\n\",\n    \"    # all the data as dictionary and over-write old data with new\\n\",\n    \"    # and then at the end compile everything together\\n\",\n    \"\\n\",\n    \"    # read raw data\\n\",\n    \"    streams = {} # stream:str -> {step: val}\\n\",\n    \"    with open(logfile, \\\"r\\\") as f:\\n\",\n    \"        for line in f:\\n\",\n    \"            parts = line.split()\\n\",\n    \"            step = int(parts[0].split(\\\":\\\")[1])\\n\",\n    \"            stream = parts[1].split(\\\":\\\")[0]\\n\",\n    \"            val = float(parts[1].split(\\\":\\\")[1])\\n\",\n    \"            if not stream in streams:\\n\",\n    \"                streams[stream] = {}\\n\",\n    \"            d = streams[stream]\\n\",\n    \"            d[step] = val\\n\",\n    \"    # now re-represent as list of (step, val) tuples\\n\",\n    \"    streams_xy = {}\\n\",\n    \"    for k, v in streams.items():\\n\",\n    \"        # get all (step, val) items, sort them\\n\",\n    \"        xy = sorted(list(v.items()))\\n\",\n    \"        # unpack the list of tuples to tuple of lists\\n\",\n    \"        streams_xy[k] = zip(*xy)\\n\",\n    \"    # return the xs, ys lists\\n\",\n    \"    return 
streams_xy\\n\",\n    \"\\n\",\n    \"parse_logfile(\\\"../log124M/main.log\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"sz = \\\"124M\\\"\\n\",\n    \"loss_baseline = {\\n\",\n    \"    \\\"124M\\\": 3.424958,\\n\",\n    \"    \\\"350M\\\": 3.083089,\\n\",\n    \"    \\\"774M\\\": 3.000580,\\n\",\n    \"    \\\"1558M\\\": 2.831273,\\n\",\n    \"}[sz]\\n\",\n    \"hella2_baseline = { # for GPT-2\\n\",\n    \"    \\\"124M\\\": 0.294463,\\n\",\n    \"    \\\"350M\\\": 0.375224,\\n\",\n    \"    \\\"774M\\\": 0.431986,\\n\",\n    \"    \\\"1558M\\\": 0.488946,\\n\",\n    \"}[sz]\\n\",\n    \"hella3_baseline = { # for GPT-3\\n\",\n    \"    \\\"124M\\\": 0.337,\\n\",\n    \"    \\\"350M\\\": 0.436,\\n\",\n    \"    \\\"774M\\\": 0.510,\\n\",\n    \"    \\\"1558M\\\": 0.547,\\n\",\n    \"}[sz]\\n\",\n    \"# assumes each model run is stored in this way\\n\",\n    \"logfile = f\\\"../log_gpt2_{sz}/main.log\\\"\\n\",\n    \"streams = parse_logfile(logfile)\\n\",\n    \"\\n\",\n    \"# optional function that smooths out the loss some\\n\",\n    \"def smooth_moving_average(signal, window_size):\\n\",\n    \"    if signal.ndim != 1:\\n\",\n    \"        raise ValueError(\\\"smooth_moving_average only accepts 1D arrays.\\\")\\n\",\n    \"    if signal.size < window_size:\\n\",\n    \"        raise ValueError(\\\"Input vector needs to be bigger than window size.\\\")\\n\",\n    \"    if window_size < 3:\\n\",\n    \"        return signal\\n\",\n    \"\\n\",\n    \"    s = np.pad(signal, (window_size//2, window_size-1-window_size//2), mode='edge')\\n\",\n    \"    w = np.ones(window_size) / window_size\\n\",\n    \"    smoothed_signal = np.convolve(s, w, mode='valid')\\n\",\n    \"    return smoothed_signal\\n\",\n    \"\\n\",\n    \"plt.figure(figsize=(16, 6))\\n\",\n    \"\\n\",\n    \"# Panel 1: losses: both train and 
val\\n\",\n    \"plt.subplot(121)\\n\",\n    \"xs, ys = streams[\\\"trl\\\"] # training loss\\n\",\n    \"ys = np.array(ys)\\n\",\n    \"# smooth out ys using a rolling window\\n\",\n    \"# ys = smooth_moving_average(ys, 21) # optional\\n\",\n    \"plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')\\n\",\n    \"print(\\\"Min Train Loss:\\\", min(ys))\\n\",\n    \"xs, ys = streams[\\\"tel\\\"] # validation loss\\n\",\n    \"plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')\\n\",\n    \"# horizontal line at GPT-2 baseline\\n\",\n    \"# we don't have GPT-3 loss on this dataset because the weights were never released\\n\",\n    \"if loss_baseline is not None:\\n\",\n    \"    plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f\\\"OpenAI GPT-2 ({sz}) checkpoint val loss\\\")\\n\",\n    \"plt.xlabel(\\\"steps\\\")\\n\",\n    \"plt.ylabel(\\\"loss\\\")\\n\",\n    \"plt.yscale('log')\\n\",\n    \"plt.ylim(top=4.0)\\n\",\n    \"plt.legend()\\n\",\n    \"plt.title(\\\"Loss\\\")\\n\",\n    \"print(\\\"Min Validation Loss:\\\", min(ys))\\n\",\n    \"\\n\",\n    \"# Panel 2: HellaSwag eval\\n\",\n    \"plt.subplot(122)\\n\",\n    \"if \\\"eval\\\" in streams:\\n\",\n    \"    xs, ys = streams[\\\"eval\\\"] # HellaSwag eval\\n\",\n    \"    ys = np.array(ys)\\n\",\n    \"    plt.plot(xs, ys, label=f\\\"llm.c ({sz})\\\")\\n\",\n    \"    # horizontal line at GPT-2/3 baselines\\n\",\n    \"    if hella2_baseline:\\n\",\n    \"        plt.axhline(y=hella2_baseline, color='r', linestyle='--', label=f\\\"OpenAI GPT-2 ({sz}) checkpoint\\\")\\n\",\n    \"    if hella3_baseline:\\n\",\n    \"        plt.axhline(y=hella3_baseline, color='g', linestyle='--', label=f\\\"OpenAI GPT-3 ({sz}) checkpoint\\\")\\n\",\n    \"    plt.xlabel(\\\"steps\\\")\\n\",\n    \"    plt.ylabel(\\\"accuracy\\\")\\n\",\n    \"    plt.legend()\\n\",\n    \"    plt.title(\\\"HellaSwag eval\\\")\\n\",\n    \"    print(\\\"Max Hellaswag eval:\\\", max(ys))\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  
\"kernelspec\": {\n   \"display_name\": \"pytorch3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.14\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "doc/layernorm/layernorm.c",
    "content": "// must run `python layernorm.py` first to generate the reference data\n// then compile for example as `gcc layernorm.c -o layernorm -lm`\n// and then run as `./layernorm` to see the output\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <math.h>\n\nvoid layernorm_forward(float* out, float* mean, float* rstd,\n                       float* inp, float* weight, float* bias,\n                       int B, int T, int C) {\n    float eps = 1e-5f;\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the input position inp[b,t,:]\n            float* x = inp + b * T * C + t * C;\n            // calculate the mean\n            float m = 0.0f;\n            for (int i = 0; i < C; i++) {\n                m += x[i];\n            }\n            m = m/C;\n            // calculate the variance (without any bias correction)\n            float v = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v/C;\n            // calculate the rstd\n            float s = 1.0f / sqrtf(v + eps);\n            // seek to the output position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i < C; i++) {\n                float n = (s * (x[i] - m)); // normalized output\n                float o = n * weight[i] + bias[i]; // scale and shift it\n                out_bt[i] = o; // write\n            }\n            // cache the mean and rstd for the backward pass later\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n\nvoid layernorm_backward(float* dinp, float* dweight, float* dbias,\n                        float* dout, float* inp, float* weight, float* mean, float* rstd,\n                        int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dout_bt = dout + b * T * C + t * 
C;\n            float* inp_bt = inp + b * T * C + t * C;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            float mean_bt = mean[b * T + t];\n            float rstd_bt = rstd[b * T + t];\n\n            // first: two reduce operations\n            float dnorm_mean = 0.0f;\n            float dnorm_norm_mean = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n            dnorm_mean = dnorm_mean / C;\n            dnorm_norm_mean = dnorm_norm_mean / C;\n\n            // now iterate again and accumulate all the gradients\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                // gradient contribution to bias\n                dbias[i] += dout_bt[i];\n                // gradient contribution to weight\n                dweight[i] += norm_bti * dout_bt[i];\n                // gradient contribution to input\n                float dval = 0.0f;\n                dval += dnorm_i; // term 1\n                dval -= dnorm_mean; // term 2\n                dval -= norm_bti * dnorm_norm_mean; // term 3\n                dval *= rstd_bt; // final scale\n                dinp_bt[i] += dval;\n            }\n        }\n    }\n}\n\n// poor man's tensor checker\nint check_tensor(float *a, float *b, int n, char* label) {\n    int ok = 1;\n    printf(\"%s\\n\", label);\n    for (int i = 0; i < n; i++) {\n        if (fabs(a[i] - b[i]) <= 1e-5) {\n            printf(\"OK \");\n        } else {\n            printf(\"NOT OK \");\n            ok = 0;\n        }\n        printf(\"%f %f\\n\", a[i], b[i]);\n    }\n    return ok;\n}\n\nint main() {\n\n    int B = 2; // batch\n    int T = 3; // time / sequence length\n    int C = 
4; // number of channels\n\n    float* x = (float*) malloc(B * T * C * sizeof(float));\n    float* w = (float*) malloc(C * sizeof(float));\n    float* b = (float*) malloc(C * sizeof(float));\n    float* out = (float*) malloc(B * T * C * sizeof(float));\n    float* mean = (float*) malloc(B * T * sizeof(float));\n    float* rstd = (float*) malloc(B * T * sizeof(float));\n    float* dout = (float*) malloc(B * T * C * sizeof(float));\n    float* dx = (float*) malloc(B * T * C * sizeof(float));\n    float* dw = (float*) malloc(C * sizeof(float));\n    float* db = (float*) malloc(C * sizeof(float));\n\n    // read reference information from Python\n    FILE *file = fopen(\"ln.bin\", \"rb\");\n    if (file == NULL) {\n        printf(\"Error opening file\\n\");\n        return 1;\n    }\n    fread(x, sizeof(float), B * T * C, file);\n    fread(w, sizeof(float), C, file);\n    fread(b, sizeof(float), C, file);\n    fread(out, sizeof(float), B * T * C, file);\n    fread(mean, sizeof(float), B * T, file);\n    fread(rstd, sizeof(float), B * T, file);\n    fread(dout, sizeof(float), B * T * C, file);\n    fread(dx, sizeof(float), B * T * C, file);\n    fread(dw, sizeof(float), C, file);\n    fread(db, sizeof(float), C, file);\n    fclose(file);\n\n    // now let's calculate everything ourselves\n\n    // forward pass\n    float* c_out = (float*) malloc(B * T * C * sizeof(float));\n    float* c_mean = (float*) malloc(B * T * sizeof(float));\n    float* c_rstd = (float*) malloc(B * T * sizeof(float));\n    layernorm_forward(c_out, c_mean, c_rstd, x, w, b, B, T, C);\n\n    // check correctness of forward pass\n    check_tensor(out, c_out, B*T*C, \"out\");\n    check_tensor(mean, c_mean, B*T, \"mean\");\n    check_tensor(rstd, c_rstd, B*T, \"rstd\");\n\n    // backward pass (note calloc inits grads to zero)\n    float* c_dx = (float*) calloc(B * T * C, sizeof(float));\n    float* c_dw = (float*) calloc(B * T, sizeof(float));\n    float* c_db = (float*) calloc(B * T, 
sizeof(float));\n    layernorm_backward(c_dx, c_dw, c_db, dout, x, w, c_mean, c_rstd, B, T, C);\n\n    // check correctness of backward pass\n    check_tensor(c_dx, dx, B*T*C, \"dx\");\n    check_tensor(c_dw, dw, C, \"dw\");\n    check_tensor(c_db, db, C, \"db\");\n\n    free(x);\n    free(w);\n    free(b);\n    free(out);\n    free(mean);\n    free(rstd);\n    free(dout);\n    free(dx);\n    free(dw);\n    free(db);\n    return 0;\n}\n"
  },
  {
    "path": "doc/layernorm/layernorm.md",
    "content": "\n# layernorm\n\nQuick tutorial. Let's look at how LayerNorm is handled, as one example layer in the model. We start with the [PyTorch docs for LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html). LayerNorm of course comes from this original paper by [Ba et al. 2016](https://arxiv.org/abs/1607.06450), and was incorporated into the Transformer in [Vaswani et al.](https://arxiv.org/abs/1706.03762) famous paper Attention is All You Need. [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) picked up the same architecture as the Transformer, but the position of the LayerNorm was famously moved into what is now called the pre-normalization version. That is, the residual path of the Transformer is kept clean, and the LayerNorms are now the first layer of each block of the Transformer. This positively improves training stability.\n\nThe first thing to note when looking at [PyTorch LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) is that you will most likely not be able to find the actual implementation of the equation. That's because it is buried 30 layers deep in the code, behind an inscrutable dynamical dispatcher, in some possibly auto-generated CUDA code (for those who are interested in details, see [layer_norm.cpp](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/layer_norm.cpp) and  [layer_norm_kernel.cu](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/layer_norm_kernel.cu)). This is done because PyTorch really really cares about efficiency, fair enough. For our purposes though, we have to start by first implementing LayerNorm manually using simpler PyTorch operations. This will be a lot less efficient than just forwarding a `LayerNorm` module, but it is algorithmically instructive. 
So here is the direct implementation of the math of LayerNorm using simpler PyTorch operations:\n\n```python\nimport torch\neps = 1e-5\n\nclass LayerNorm:\n\n    @staticmethod\n    def forward(x, w, b):\n        # x is the input activations, of shape B,T,C\n        # w are the weights, of shape C\n        # b are the biases, of shape C\n        B, T, C = x.size()\n        # calculate the mean\n        mean = x.sum(-1, keepdim=True) / C # B,T,1\n        # calculate the variance\n        xshift = x - mean # B,T,C\n        var = (xshift**2).sum(-1, keepdim=True) / C # B,T,1\n        # calculate the inverse standard deviation: **0.5 is sqrt, **-0.5 is 1/sqrt\n        rstd = (var + eps) ** -0.5 # B,T,1\n        # normalize the input activations\n        norm = xshift * rstd # B,T,C\n        # scale and shift the normalized activations at the end\n        out = norm * w + b # B,T,C\n\n        # return the output and the cache, of variables needed later during the backward pass\n        cache = (x, w, mean, rstd)\n        return out, cache\n```\n\nThe activation tensors in the residual path of the Transformer during training are 3-dimensional arrays (tensors), of shape `B,T,C`. B is the batch size, T is time, and C is channels. For example, B=8, T=1024, C=768 is one setting you might see, for the smallest (124 million parameter) GPT-2 model.\n\nWe can forward this layer with some random numbers:\n\n```python\nB = 2 # some toy numbers here\nT = 3\nC = 4\nx = torch.randn(B, T, C, requires_grad=True)\nw = torch.randn(C, requires_grad=True)\nb = torch.randn(C, requires_grad=True)\nout, cache = LayerNorm.forward(x, w, b)\n```\n\nWhat we get out is the tensor `out`, also of shape `B,T,C`, where each C-dimensional \"fibre\" of activations (as we call them) is normalized and then scaled and at the end also shifted by the weights and biases of this layer. 
Notice that, importantly, we also return a variable `cache`, which is a tuple of the input activations `x`, the weights `w`, the mean `mean`, and the reciprocal standard deviation `rstd`. These are all variables we need during the backward pass.\n\nPyTorch can of course do the backward pass of this layer for us with its Autograd. Let's do that first:\n\n```python\ndout = torch.randn(B, T, C)\nfakeloss = (out * dout).sum()\nfakeloss.backward()\n```\n\nYou see here that we created a `fakeloss`, which simply takes a (random) weighted combination of all the outputs of our layernorm. All this is doing is projecting all of the `B,T,C` numbers into a single scalar value (loss), so that we have a single output of our \"computational graph\". Typically this would be the loss of the model, but here we're just doing a fake loss. We then call `backward()` on this scalar, and PyTorch will compute all the gradients for us on all the inputs to this graph - i.e. the input activations `x`, the weights `w`, and the biases `b`. If you don't know too much about autograd, I'd encourage you to watch my [micrograd](https://www.youtube.com/watch?v=VMj-3S1tku0) video, where we build a tiny autograd engine. So the magic of PyTorch autograd is that after we call `.backward`, it will populate the `.grad` attribute of all the tensors that have `requires_grad=True` with the gradients of the loss with respect to that tensor. These gradients are telling us the slope of the loss for all of the input numbers in x,w,b. Therefore, the shape of `x.grad`, `w.grad`, and `b.grad` are exactly the same as the shape of `x`, `w`, and `b`.\n\nBut we don't want to use PyTorch Autograd. We want to do the backward pass manually. So we take out pen and paper and write out the expression for LayerNorm. 
The forward pass has the following mathematical form:\n\n$\\text{LayerNorm}(x) = w \\odot \\frac{x - \\mu}{\\sqrt{\\sigma^2 + \\epsilon}} + b$\n\nwhere $\\odot$ is elementwise multiplication, $\\mu$ is the mean, $\\sigma^2$ is the variance, and $\\epsilon$ is a small constant to avoid division by zero. Remembering the rules of differentiation from calculus, we now want to derive the gradients. For this part, my video [Becoming a Backprop Ninja](https://www.youtube.com/watch?v=q8SA3rM6ckI) could be very helpful, as I work through (in detail) a similar layer - the Batch Normalization layer. When you work through the differentiation, you'll notice that the expressions simplify analytically and you can move the terms around and simplify the expression somewhat. So you don't have to manually backward every individual line in the forward pass. In particular, we get:\n\n```python\n    @staticmethod\n    def backward(dout, cache):\n        x, w, mean, rstd = cache\n        # recompute the norm (save memory at the cost of compute)\n        norm = (x - mean) * rstd\n        # gradients for weights, bias\n        db = dout.sum((0, 1))\n        dw = (dout * norm).sum((0, 1))\n        # gradients for input\n        dnorm = dout * w\n        dx = dnorm - dnorm.mean(-1, keepdim=True) - norm * (dnorm * norm).mean(-1, keepdim=True)\n        dx *= rstd\n        return dx, dw, db\n```\n\nSo given the gradients on every individual output number stored in `dout`, and the `cache` from the forward pass, we can now backward through this layer into the inputs, to continue the chain rule of the backward pass. So now we can do our own backward pass and see that they match (the errors are tiny):\n\n```python\ndx, dw, db = LayerNorm.backward(dout, cache)\nprint(\"dx error:\", (x.grad - dx).abs().max().item())\nprint(\"dw error:\", (w.grad - dw).abs().max().item())\nprint(\"db error:\", (b.grad - db).abs().max().item())\n```\n\nNotice one more thing. 
Inside the backward pass we recomputed the variable `norm`. We already calculated this variable in the forward pass but then we threw it away! Couldn't we have made this also be a part of the `cache` and save this recompute? Actually, we very well could and you'd of course get the exact same results. The amount of stuff we save into our `cache` is completely up to us. We didn't even have to save `mean` and `rstd` either, and we could have recomputed them in the backward pass. The difference is that `mean` and `rstd` are very small, only of shape `B,T`, whereas `norm` is of shape `B,T,C`. So this is simply a tradeoff between memory and compute. By not keeping `norm` in the cache, we are saving memory, but we are trading it off for a bit of compute later in the backward pass. This is very common in all the layers, and you'll see that different implementations of various layers in deep learning frameworks may all have different \"checkpointing settings\". Yes, confusingly enough, this is called checkpointing and has nothing to do with saving the model weights to disk. It's about saving intermediate variables in the forward pass to save compute in the backward pass.\n\nOkay so that's the version with PyTorch tensors. Now we have to move this to C and get rid of the Tensor abstraction. Before I give you the full implementation of the forward pass, a brief word on Tensors. What are Tensors? They are 1) a 1D block of memory called Storage that holds the raw data, and 2) a View over that storage that holds its shape. [PyTorch Internals](http://blog.ezyang.com/2019/05/pytorch-internals/) could be helpful here. 
So for example if we have the 3D tensor:\n\n```python\ntorch.manual_seed(42)\nB, T, C = 2, 3, 4\na = torch.randn(B, T, C)\nprint(a)\n\ntensor([[[ 1.9269,  1.4873,  0.9007, -2.1055],\n         [ 0.6784, -1.2345, -0.0431, -1.6047],\n         [ 0.3559, -0.6866, -0.4934,  0.2415]],\n\n        [[-1.1109,  0.0915, -2.3169, -0.2168],\n         [-0.3097, -0.3957,  0.8034, -0.6216],\n         [-0.5920, -0.0631, -0.8286,  0.3309]]])\n```\n\nThis is a 2x3x4 Tensor, but the underlying memory of it is just one single 1D array of size 2\\*3\\*4=24. The View is just a shape over this 1D array. So now when we index into this PyTorch tensor, for example `a[1,2,3]`, PyTorch computes the offset into the 1D array as `1*3*4 + 2*4 + 3 = 23`, and returns the value at that offset. The general formula is that if you want to retrieve any element `b,t,c`, you compute the offset into Storage as `b*T*C + t*C + c`. So for example:\n\n```python\nb,t,c = 1,2,3\nprint(a[b,t,c])\nprint(a.view(-1)[b*T*C + t*C + c])\n```\n\nBoth of these print 0.3309. So in this way, we know how to access all the individual elements, and how to offset all the pointers. Notice in particular that the channel dimension is the innermost dimension. So as we increase offset by 1, we are traversing the channel dimension. This is important to consider for the memory layout of our C implementation. 
The equivalent forward pass in C becomes:\n\n```c\n#include <stdio.h>\n#include <stdlib.h>\n#include <math.h>\n\nvoid layernorm_forward(float* out, float* mean, float* rstd,\n                       float* inp, float* weight, float* bias,\n                       int B, int T, int C) {\n    float eps = 1e-5f;\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the input position inp[b,t,:]\n            float* x = inp + b * T * C + t * C;\n            // calculate the mean\n            float m = 0.0f;\n            for (int i = 0; i < C; i++) {\n                m += x[i];\n            }\n            m = m/C;\n            // calculate the variance (without any bias correction)\n            float v = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v/C;\n            // calculate the rstd\n            float s = 1.0f / sqrtf(v + eps);\n            // seek to the output position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i < C; i++) {\n                float n = (s * (x[i] - m)); // normalized output\n                float o = n * weight[i] + bias[i]; // scale and shift it\n                out_bt[i] = o; // write\n            }\n            // cache the mean and rstd for the backward pass later\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n```\n\nYou'll see how I offset the pointer to the `inp[b,t]`, and then you know that the next `C` elements are the channels of that position in (batch, time). 
And the backward pass:\n\n```c\nvoid layernorm_backward(float* dinp, float* dweight, float* dbias,\n                        float* dout, float* inp, float* weight, float* mean, float* rstd,\n                        int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dout_bt = dout + b * T * C + t * C;\n            float* inp_bt = inp + b * T * C + t * C;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            float mean_bt = mean[b * T + t];\n            float rstd_bt = rstd[b * T + t];\n\n            // first: two reduce operations\n            float dnorm_mean = 0.0f;\n            float dnorm_norm_mean = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n            dnorm_mean = dnorm_mean / C;\n            dnorm_norm_mean = dnorm_norm_mean / C;\n\n            // now iterate again and accumulate all the gradients\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                // gradient contribution to bias\n                dbias[i] += dout_bt[i];\n                // gradient contribution to weight\n                dweight[i] += norm_bti * dout_bt[i];\n                // gradient contribution to input\n                float dval = 0.0f;\n                dval += dnorm_i; // term 1\n                dval -= dnorm_mean; // term 2\n                dval -= norm_bti * dnorm_norm_mean; // term 3\n                dval *= rstd_bt; // final scale\n                dinp_bt[i] += dval;\n            }\n        }\n    }\n}\n```\n\nOne additional detail to note is that we always += into the gradients. We never use = and we never use *=. 
This is important stylistically because if you have one variable used multiple times in a graph, the backward pass gradients always add up. In this repo this is not important because we don't have exotic branching, but it's proper. So during training we always first do `zero_grad` to set all the gradients to zero, and then we accumulate into them during backward pass.\n\nOne more note on differences between training and inference. Some of you may have already seen my earlier project [llama2.c](https://github.com/karpathy/llama2.c), which inferences Llama 2 architecture in pure C. Unlike GPT-2, Llama 2 swaps out LayerNorm for the much simpler RMSNorm. You can see the implementation of the [RMSNorm in llama2.c](https://github.com/karpathy/llama2.c/blob/master/run.c#L182), copy pasting it here:\n\n```c\nvoid rmsnorm(float* o, float* x, float* weight, int size) {\n    // calculate sum of squares\n    float ss = 0.0f;\n    for (int j = 0; j < size; j++) {\n        ss += x[j] * x[j];\n    }\n    ss /= size;\n    ss += 1e-5f;\n    ss = 1.0f / sqrtf(ss);\n    // normalize and scale\n    for (int j = 0; j < size; j++) {\n        o[j] = weight[j] * (ss * x[j]);\n    }\n}\n```\n\nHow does this differ to our LayerNorm above?\n\n- First, algorithmically, you'll notice that RMSNorm does not keep track of or subtract the mean, it only normalizes by the norm. Notice: norm, not standard deviation, because we did not subtract the mean. This is a simplification of the layer that has now become very trendy because it works just as well, if not slightly better. Also, the RMSNorm does not have biases, it only has a weight for scaling after normalization. In general, GPT-2 used way too many biases everywhere and it turns out you can remove these - from all the Linear Layers and from LayerNorms. The network can \"simulate\" biases if it needs them, e.g. 
by allocating one of the channel dimensions to be constant (data-independent), and then any weight multiplying that constant dimension will effectively work like a bias. This significantly simplifies a lot of the code.\n- Second, the inference code has no batch dimension B, i.e. the batch size is assumed to be 1. You could in principle have batched inference as well, especially if you wish to host an LLM that you expect many simultaneous queries to. But if you're just running an LLM locally, chances are you just want to have a single \"stream\" of generation, so there is no batch size for parallelism that could support multiple streams at once. To keep things simple, llama2.c is not batched, and therefore you won't see any loops that look like `for (int b = 0; b < B; b++)`.\n- Third, this inference code has no time dimension T within this individual layer. During training, we can loop over time inside each layer and calculate the layernorm at all time steps. But during inference, we have to generate one token at a time, feeding the token predicted at time `t` into the forward pass of the Transformer at the next time step `t+1`. So this is why you don't see any loops that look like `for (int t = 0; t < T; t++)` inside individual layers. This loop over time [does exist](https://github.com/karpathy/llama2.c/blob/master/run.c#L747), but it is on the outside of the Transformer forward pass.\n- You'll see that we don't keep track of any intermediate calculations, memory, or cache. That's because during inference, there is no `.backward` pass that will follow. We only need to calculate the output, and we don't need to keep any intermediate variables around. As a result, the memory consumption of inference is significantly lower than that of training. We can afford to just discard activations, and only keep memory for the \"activation frontier\". 
Similarly, there is no need to implement the `backward` function for this RMSNorm anywhere, as there is no backward pass.\n\nAs a result of all these differences, training is significantly more complex and involved, both algorithmically and computationally, and that's partly why I started by writing inference (llama2.c) before I implemented training (llm.c, here). Finally, I am attaching two helper files to this same directory that have the complete code. First:\n\n```\npython layernorm.py\n```\n\nTo write out the reference data from PyTorch. Then compile and run the C version:\n\n```\ngcc layernorm.c -o layernorm -lm\n./layernorm\n```\n\nYou'll see that everything matches ok.\n\nThis was just the LayerNorm. We go through the exact same process for all the other layers. Most of the other layers are actually easier than LayerNorm. Hope that helps!\n"
  },
  {
    "path": "doc/layernorm/layernorm.py",
    "content": "import torch\n\neps = 1e-5\n\nclass LayerNorm:\n\n    @staticmethod\n    def forward(x, w, b):\n        B, T, C = x.size()\n        mean = x.sum(-1, keepdim=True) / C # B,T,1\n        xshift = x - mean # B,T,C\n        var = (xshift**2).sum(-1, keepdim=True) / C # B,T,1\n        rstd = (var + eps) ** -0.5 # B,T,1\n        norm = xshift * rstd # B,T,C\n        out = norm * w + b # B,T,C\n\n        cache = (x, w, mean, rstd)\n        return out, cache\n\n    @staticmethod\n    def backward(dout, cache):\n        x, w, mean, rstd = cache\n        # recompute the norm (save memory at the cost of compute)\n        norm = (x - mean) * rstd\n        # gradients for weights, bias\n        db = dout.sum((0, 1))\n        dw = (dout * norm).sum((0, 1))\n        # gradients for input\n        dnorm = dout * w\n        dx = dnorm - dnorm.mean(-1, keepdim=True) - norm * (dnorm * norm).mean(-1, keepdim=True)\n        dx *= rstd\n        return dx, dw, db\n\n# create a small dummy example and check w.r.t PyTorch backward\nB = 2\nT = 3\nC = 4\nx = torch.randn(B, T, C, requires_grad=True)\nw = torch.randn(C, requires_grad=True)\nb = torch.randn(C, requires_grad=True)\nout, cache = LayerNorm.forward(x, w, b)\n\ndout = torch.randn(B, T, C)\ndx, dw, db = LayerNorm.backward(dout, cache)\n\n# compare to PyTorch autograd\nfakeloss = (out * dout).sum()\nfakeloss.backward()\nprint(\"dx error:\", (x.grad - dx).abs().max().item())\nprint(\"dw error:\", (w.grad - dw).abs().max().item())\nprint(\"db error:\", (b.grad - db).abs().max().item())\n\n# for reference checking in C also\nx, w, mean, rstd = cache\n\ndef write(tensor, handle):\n    handle.write(tensor.detach().numpy().astype(\"float32\").tobytes())\n\n# Write to file\nwith open('ln.bin', 'wb') as file:\n    write(x, file) # (B, T, C)\n    write(w, file) # (C, )\n    write(b, file) # (C, )\n    write(out, file) # (B, T, C)\n    write(mean, file) # (B, T)\n    write(rstd, file) # (B, T)\n    write(dout, file) # (B, T, C)\n 
   write(dx, file) # (B, T, C)\n    write(dw, file) # (C, )\n    write(db, file) # (C, )\n"
  },
  {
    "path": "llmc/adamw.cuh",
    "content": "/*\nAdamW kernel\n*/\n\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\n// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation).\n// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda\n__device__ float lerp(float start, float end, float weight) {\n    return fma(weight, end, fma(-weight, start, start));\n}\n\ntemplate <typename Tp, typename Tg>\n__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,\n                             float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,\n                             float grad_scale, unsigned int seed) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= num_parameters) { return; }  // guard\n\n    // get the gradient, m, and v for this parameter\n    float grad = grad_scale * (float)grads_memory[idx];\n    float m = m_memory[idx];\n    float v = v_memory[idx];\n    // update the first moment (momentum)\n    m = lerp(grad, m, beta1);\n    m_memory[idx] = m;\n    // update the second moment (RMSprop)\n    v = lerp(grad * grad, v, beta2);\n    v_memory[idx] = v;\n    m /= beta1_correction;  // m_hat\n    v /= beta2_correction;  // v_hat\n    // fetch the old value of this parameter as a float, from either source\n    float old_param = (master_params_memory != NULL) ? 
master_params_memory[idx] : (float)params_memory[idx];\n    // update this parameter\n    float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));\n    // update our low precision version of the parameters using stochastic rounding\n    // this will be used in the next forward pass\n    stochastic_rounding(param, &params_memory[idx], seed);\n    // write the full, float version of the param into our master copy, if we maintain one\n    // this will be used in the next update\n    if (master_params_memory != NULL) { master_params_memory[idx] = param; }\n}\n\ntemplate <typename Tp, typename Tg>\n__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,\n                              ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride,\n                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,\n                              float grad_scale, unsigned int seed) {\n    adamw_update(params_memory + blockIdx.y * w_stride,\n                 master_params_memory ? 
master_params_memory + blockIdx.y * s_stride : NULL,\n                 grads_memory + blockIdx.y * g_stride,\n                 m_memory + blockIdx.y * s_stride,\n                 v_memory + blockIdx.y * s_stride,\n                 num_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale,\n                 seed\n                 );\n}\n\ntemplate <typename Tp>\n__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,\n                                          ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed) {\n    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= num_parameters) { return; }\n    params_memory += blockIdx.y * w_stride; // adjust for layer offset\n    master_params_memory += blockIdx.y * s_stride;\n    stochastic_rounding(master_params_memory[idx], &params_memory[idx], seed);\n}\n\ntemplate <typename Tp, typename Tg>\nvoid adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,\n                  ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride,  int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,\n                  float grad_scale, unsigned int seed, cudaStream_t stream) {\n    // AdamW update\n    int block_size = 512;\n    int num_blocks = CEIL_DIV(num_parameters, block_size);\n    float beta1_correction = 1.0f - powf(beta1, t);\n    float beta2_correction = 1.0f - powf(beta2, t);\n    adamw_kernel3<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>(params_memory, master_params_memory, grads_memory,\n                                                         m_memory, v_memory, num_parameters, w_stride, g_stride, s_stride,\n                                                         learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,\n           
                                              grad_scale, seed);\n    cudaCheck(cudaGetLastError());\n}\n\ntemplate <typename Tp>\nvoid init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,\n                        ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream) {\n    int block_size = 512; // must match block size of adamw_update so that RNG also matches\n    int num_blocks = CEIL_DIV(num_parameters, block_size);\n    init_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>\n                             (params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/attention.cuh",
    "content": "/*\nAttention, as a fallback when we do not use the Flash Attention from cuDNN\n*/\n#include <assert.h>\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n#include \"cublas_common.h\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\n// inputs floatX, outputs FP32 (for current FP32-only activation path for this WIP)\n__global__ void permute_kernel(floatX* q, floatX* k, floatX* v,\n                               const floatX* inp,\n                               int B, int N, int NH, int d) {\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)\n    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= B * NH * N * d) { return; }\n\n    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n    int b = idx / (NH * N * d);\n    int rest = idx % (NH * N * d);\n    int nh_ = rest / (N * d);\n    rest = rest % (N * d);\n    int n = rest / d;\n    int d_ = rest % d;\n    int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n    q[idx] = __ldcs(&inp[inp_idx]);\n    k[idx] = __ldcs(&inp[inp_idx + NH * d]);\n    v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]);\n}\n\n__global__ void permute_kernel_backward(floatX* dinp,\n                                        const floatX* dq, const floatX* dk, const floatX* dv,\n                                        int B, int N, int NH, int d) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= B * NH * N * d) { return; }\n\n    int b = idx / (NH * N * d);\n    int rest = idx % (NH * N * d);\n    int nh_ = rest / (N * d);\n    rest = rest % (N * d);\n    int n = rest / d;\n    int d_ = rest % d;\n\n    int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n    dinp[inp_idx] = dq[idx];\n    dinp[inp_idx + NH * d] = dk[idx];\n    dinp[inp_idx + 2 * 
(NH * d)] = dv[idx];\n}\n\n__global__ void unpermute_kernel(floatX* inp, floatX *out, int B, int N, int NH, int d) {\n   // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x);\n    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]\n    if (idx >= B * NH * N * d) { return; }\n\n    int b = idx / (NH * N * d);\n    int rest = idx % (NH * N * d);\n    int nh_ = rest / (N * d);\n    rest = rest % (N * d);\n    int n = rest / d;\n    int d_ = rest % d;\n    int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n    out[other_idx] = __ldcs(&inp[idx]);\n}\n\n__global__ void unpermute_kernel_backward(floatX* dinp, const floatX *dout, int B, int N, int NH, int d) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx >= B * NH * N * d) { return; }\n\n    int b = idx / (NH * N * d);\n    int rest = idx % (NH * N * d);\n    int nh_ = rest / (N * d);\n    rest = rest % (N * d);\n    int n = rest / d;\n    int d_ = rest % d;\n    int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n    dinp[idx] = (floatX)dout[other_idx];\n}\n\n__global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, const floatX* inp, int N, int T) {\n    // inp, out shape: (N, T, T), where N = B * NH\n    // fuses the multiplication by scale inside attention\n    // directly autoregressive, so we only compute the lower triangular part\n    // uses the online softmax algorithm\n    assert(T % 4  == 0);\n    int lane_id = threadIdx.x % WARP_SIZE;\n    int warp_id = threadIdx.x / WARP_SIZE;\n    int num_warps = blockDim.x / WARP_SIZE;\n\n    // micro-optimization: we iterate backwards so that\n    // after the softmax backward operation completes, the cache retains the\n    // part of the matrix close to the upper left corner, which benefits the\n    // matmul operation that immediately follows.\n    // int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); // 
forward order\n    int idx = (gridDim.x - blockIdx.x - 1) * num_warps + warp_id; // backward order\n    if(idx >= N * T) {\n        return;\n    }\n    int own_pos = idx % T;\n    int pos_by_4 = own_pos / 4;\n\n    // one row of inp, i.e. inp[idx, :] of shape (T,)\n    const floatX* x = inp + idx * T;\n\n    // not INF, so we don't get NaNs accidentally when subtracting two values.\n    const float flt_max = 340282346638528859811704183484516925440.0f; // to avoid including float.h\n    float maxval = -flt_max;\n    float sumval = 0.0f;\n\n    const floatX* x_aligned = reinterpret_cast<const floatX*>(__builtin_assume_aligned(x, 16));\n    for (int i = lane_id; i < pos_by_4; i += WARP_SIZE) {\n        float regarray[4];\n        for (int k = 0; k < 4; ++k) {\n            regarray[k] = (float)x_aligned[4*i + k];\n        }\n        float old_maxval = maxval;\n        for(int k = 0; k < 4; ++k) {\n            maxval = fmaxf(maxval, regarray[k]);\n        }\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        for(int k = 0; k < 4; ++k) {\n            sumval += expf(inv_temperature * (regarray[k] - maxval));\n        }\n    }\n\n    if(4*pos_by_4 + lane_id <= own_pos) {\n        float old_maxval = maxval;\n        maxval = fmaxf(maxval, (float)x[4*pos_by_4 + lane_id]);\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        sumval += expf(inv_temperature * ((float)x[4*pos_by_4 + lane_id] - maxval));\n    }\n\n    float global_maxval = warpReduceMax(maxval);\n    sumval *= expf(inv_temperature * (maxval - global_maxval));\n\n    float sum = warpReduceSum(sumval);\n    float norm = 1.f / sum;\n\n    // divide the whole row by the sum\n    for (int i = lane_id; i <= own_pos; i += WARP_SIZE) {\n        // recalculation is faster than doing the round-trip through memory.\n        float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval));\n        __stcs(out + idx * T + i, (floatX)(ev * norm));\n    }\n}\n\n__global__ 
void softmax_autoregressive_backward_inplace_kernel(floatX* datt, const floatX* att,\n                                                               int B, int T, int C, float scale) {\n    constexpr const int BlockSize = 256;\n    constexpr int T_per_block = 4;\n\n    // go through blocks in reverse order, so the slowest block starts first\n    int t0 = T - 1 - T_per_block*blockIdx.x;\n    int idx = blockIdx.y;\n\n    att += idx * T * T;\n    datt += idx * T * T;\n\n    for(int to = 0; to < T_per_block; ++to) {\n        int t = t0 - to;\n        if(t < 0) return;\n        const floatX* att_bth = att + t * T;\n        const floatX* datt_bth = datt + t * T;\n        floatX* dpreatt_bth = datt + t * T;\n\n        float local_sum = 0;\n        for (int t2 = threadIdx.x; t2 <= t; t2 += BlockSize) {\n            local_sum += (float)att_bth[t2] * (float)datt_bth[t2];\n        }\n\n        local_sum = blockReduce<warpReduceSum>(local_sum);\n\n        for (int t3 = threadIdx.x; t3 < T; t3 += BlockSize) {\n            // don't touch the cache. 
Some parts will still be here from the previous loop, and\n            // we want to exploit those.\n            if(t3 <= t) {\n                float acc = (float) __ldcs(att_bth + t3) * ((float) __ldcs(datt_bth + t3) - local_sum);\n                __stcs(dpreatt_bth + t3, (floatX) (scale * acc));\n            } else {\n                // explicitly set non-causal elements to zero\n                __stcs(dpreatt_bth + t3, (floatX)0.f);\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\nvoid attention_forward(floatX* out, floatX* qkvr, floatX* att,\n                       floatX* inp,\n                       int B, int T, int C, int NH, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.\n    // Its contents will be overwritten by this function.\n    const int block_size = 256;\n\n    // inp is (B, T, 3C) QKV\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n    const int HS = C / NH; // head size\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    floatX *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    int total_threads = B * NH * T * HS;\n    int num_blocks = CEIL_DIV(total_threads, block_size);\n    permute_kernel<<<num_blocks, block_size, 0, stream>>>(q, k, v, inp, B, T, NH, HS);\n\n    floatX* preatt = inp; // reuse inp as scratch buffer\n    matmul_cublaslt(preatt, k, q, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T);\n\n    // multiply all elements of preatt elementwise by scale\n    float scale = 1.f / sqrtf(HS);\n    int grid_size = CEIL_DIV(B * NH * T * WARP_SIZE, block_size);\n    softmax_forward_kernel5<<<grid_size, block_size, 0, stream>>>(att, scale, preatt, B * NH, T);\n\n    // new approach: first cuBLAS another batched matmul\n    floatX* vaccum 
= inp;\n    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)\n    matmul_cublaslt(vaccum, v, att, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS);\n\n    // now unpermute\n    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n    num_blocks = CEIL_DIV(B * T * C, block_size);\n    unpermute_kernel<<<num_blocks, block_size, 0, stream>>>(vaccum, out, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n}\n\n// the sequence of transformations in this compound op is:\n// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)\nvoid attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch,\n                        const floatX* dout,\n                        const floatX* qkvr, const floatX* att,\n                        int B, int T, int C, int NH, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 256;\n    const int HS = C / NH; // head size\n\n    // unpack convenience pointers into q, k, v\n    const floatX *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    floatX *dq, *dk, *dv;\n    dq = dqkvr + 0 * B * T * C;\n    dk = dqkvr + 1 * B * T * C;\n    dv = dqkvr + 2 * B * T * C;\n\n    // backward through the unpermute operation\n    int num_blocks = CEIL_DIV(B * T * C, block_size);\n    unpermute_kernel_backward<<<num_blocks, block_size, 0, stream>>>(scratch, dout, B, T, NH, HS);\n    // backward into datt\n    matmul_cublaslt(datt, v, scratch, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T);\n    // backward into dv\n    matmul_cublaslt(dv, scratch, att, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS);\n    const float scale = 1.0f / sqrtf((float)HS);\n    // backward into preatt. 
this is an in-place operation; datt turns into dpreatt here\n    softmax_autoregressive_backward_inplace_kernel<<<dim3(T / 4, B * NH), 256, 0, stream>>>(datt, att, B, T, C, scale);\n    const floatX* dpreatt = datt;\n    // backward into q\n    matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS);\n    // backward into k\n    matmul_cublaslt(dk, q, dpreatt, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS);\n    // backward into inp\n    num_blocks = CEIL_DIV(B * NH * T * HS, block_size);\n    permute_kernel_backward<<<num_blocks, block_size, 0, stream>>>(dinp, dq, dk, dv, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/cublas_common.h",
    "content": "/*\ncuBLAS related utils\n*/\n#ifndef CUBLAS_COMMON_H\n#define CUBLAS_COMMON_H\n\n#include <stddef.h>\n#include <stdlib.h>\n#include <stdio.h>\n#include <cublas_v2.h>\n#include <cublasLt.h>\n\n// ----------------------------------------------------------------------------\n// cuBLAS Precision settings\n\n#if defined(ENABLE_FP32)\n#define CUBLAS_LOWP CUDA_R_32F\n#elif defined(ENABLE_FP16)\n#define CUBLAS_LOWP CUDA_R_16F\n#else // default to bfloat16\n#define CUBLAS_LOWP CUDA_R_16BF\n#endif\n\n// ----------------------------------------------------------------------------\n// cuBLAS globals for workspace, handle, settings\n\n// Hardcoding workspace to 32MiB but only Hopper needs 32 (for others 4 is OK)\nconst size_t cublaslt_workspace_size = 32 * 1024 * 1024;\nvoid* cublaslt_workspace = NULL;\ncublasComputeType_t cublas_compute = CUBLAS_COMPUTE_32F;\ncublasLtHandle_t cublaslt_handle;\n\n// ----------------------------------------------------------------------------\n// Error checking\n\n// cuBLAS error checking\nvoid cublasCheck(cublasStatus_t status, const char *file, int line)\n{\n    if (status != CUBLAS_STATUS_SUCCESS) {\n        printf(\"[cuBLAS ERROR]: %d %s %d\\n\", status, file, line);\n        exit(EXIT_FAILURE);\n    }\n}\n#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }\n\n#endif // CUBLAS_COMMON_H"
  },
  {
    "path": "llmc/cuda_common.h",
    "content": "/*\nCommon utilities for CUDA code.\n*/\n#ifndef CUDA_COMMON_H\n#define CUDA_COMMON_H\n\n#include <stdlib.h>\n#include <stdio.h>\n#include <math.h>\n#include <string>\n#include <type_traits>      // std::bool_constant\n#include <cuda_runtime.h>\n#include <nvtx3/nvToolsExt.h>\n#include <nvtx3/nvToolsExtCudaRt.h>\n#include <cuda_profiler_api.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#include \"utils.h\"\n\n// ----------------------------------------------------------------------------\n// Global defines and settings\n\n// Device properties of the CUDA device used in this process\n// defined as extern here because the individual kernels wish to use it\n// but it is actually created and instantiated in the main program file\nextern cudaDeviceProp deviceProp;\n\n// WarpSize is not a compile time constant\n// Defining here like this possibly allows the compiler to optimize better\n#define WARP_SIZE 32U\n\n// try to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance\n// this needs to be defines rather than queried to be used for __launch_bounds__\n#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900\n#define MAX_1024_THREADS_BLOCKS 2\n#else\n#define MAX_1024_THREADS_BLOCKS 1\n#endif\n\n// convenience macro for calculating grid/block dimensions for kernels\n#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))\n\n// short-cuts for compile-time boolean values that can be used as function arguments\nconstexpr std::bool_constant<true> True;\nconstexpr std::bool_constant<true> False;\n\n// ----------------------------------------------------------------------------\n// Error checking\n\n// CUDA error checking. 
Underscore added so this function can be called directly not just via macro\ninline void cudaCheck_(cudaError_t error, const char *file, int line) {\n  if (error != cudaSuccess) {\n    printf(\"[CUDA ERROR] at file %s:%d:\\n%s\\n\", file, line, cudaGetErrorString(error));\n    exit(EXIT_FAILURE);\n  }\n};\n#define cudaCheck(err) (cudaCheck_(err, __FILE__, __LINE__))\n\n// like cudaFree, but checks for errors _and_ resets the pointer.\ntemplate<class T>\ninline void cudaFreeCheck(T** ptr, const char *file, int line) {\n    cudaError_t error = cudaFree(*ptr);\n    if (error != cudaSuccess) {\n        printf(\"[CUDA ERROR] at file %s:%d:\\n%s\\n\", file, line, cudaGetErrorString(error));\n        exit(EXIT_FAILURE);\n    }\n    *ptr = nullptr;\n}\n#define cudaFreeCheck(ptr) (cudaFreeCheck(ptr, __FILE__, __LINE__))\n\n// ----------------------------------------------------------------------------\n// CUDA Precision settings and defines\n\nenum PrecisionMode {\n    PRECISION_FP32,\n    PRECISION_FP16,\n    PRECISION_BF16\n};\n\n// Specific configurations based on the enabled precision\n#if defined(ENABLE_FP32)\ntypedef float floatX;\n#define PRECISION_MODE PRECISION_FP32\n// use fp16 (note: this may require gradient scaler, currently not implemented!)\n#elif defined(ENABLE_FP16)\ntypedef half floatX;\n#define PRECISION_MODE PRECISION_FP16\n#else // Default to bfloat16\ntypedef __nv_bfloat16 floatX;\n#define PRECISION_MODE PRECISION_BF16\n#endif\n\n// ----------------------------------------------------------------------------\n// Load and store with streaming cache hints\n// Older nvcc does not provide __ldcs and __stcs for bfloat16, despite these\n// actually just being unsigned shorts. 
We need to be careful here to only define\n// our own versions if none already exist, otherwise the compiler will complain.\n// If not, you easily get \"no viable overload\" (for sm52) and \"function already exists\" (sm_80)\n\n#if defined(ENABLE_BF16) && (__CUDACC_VER_MAJOR__ < 12) && !((__CUDA_ARCH__ >= 800) || !defined(__CUDA_ARCH__))\n__device__ floatX __ldcs(const floatX* address) {\n    unsigned short bf = __ldcs(reinterpret_cast<const unsigned short*>(address));\n    return __nv_bfloat16_raw{bf};\n}\n\n__device__ void __stcs(floatX* address, floatX value) {\n    __stcs(reinterpret_cast<unsigned short*>(address), ((__nv_bfloat16_raw)value).x);\n}\n#endif\n\n// ----------------------------------------------------------------------------\n// Profiler utils\n\nclass NvtxRange {\n public:\n    NvtxRange(const char* s) { nvtxRangePush(s); }\n    NvtxRange(const std::string& base_str, int number) {\n        std::string range_string = base_str + \" \" + std::to_string(number);\n        nvtxRangePush(range_string.c_str());\n    }\n    ~NvtxRange() { nvtxRangePop(); }\n};\n#define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)\n\n// ----------------------------------------------------------------------------\n// Utilities to Read & Write between CUDA memory <-> files\n\n// copy num_bytes from device pointer src into file dest, using double buffering running on the given stream.\ninline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {\n    // allocate pinned buffer for faster, async transfer\n    char* buffer_space;\n    cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size));\n    // split allocation in two\n    void* read_buffer = buffer_space;\n    void* write_buffer = buffer_space + buffer_size;\n\n    // prime the read buffer; first copy means we have to wait\n    char* gpu_read_ptr = (char*)src;\n    size_t copy_amount = std::min(buffer_size, num_bytes);\n    cudaCheck(cudaMemcpyAsync(read_buffer, 
gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));\n    cudaCheck(cudaStreamSynchronize(stream));\n    size_t rest_bytes = num_bytes - copy_amount;\n    size_t write_buffer_size = copy_amount;\n    gpu_read_ptr += copy_amount;\n\n    std::swap(read_buffer, write_buffer);\n    // now the main loop; as long as there are bytes left\n    while(rest_bytes > 0) {\n        // initiate next read\n        copy_amount = std::min(buffer_size, rest_bytes);\n        cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));\n        // while this is going on, transfer the write buffer to disk\n        fwriteCheck(write_buffer, 1, write_buffer_size, dest);\n        cudaCheck(cudaStreamSynchronize(stream));     // wait for both buffers to be ready.\n\n        std::swap(read_buffer, write_buffer);\n        rest_bytes -= copy_amount;\n        write_buffer_size = copy_amount;\n        gpu_read_ptr += copy_amount;\n    }\n\n    // make sure to write the last remaining write buffer\n    fwriteCheck(write_buffer, 1, write_buffer_size, dest);\n    cudaCheck(cudaFreeHost(buffer_space));\n}\n\n// copy num_bytes from file src into device pointer dest, using double buffering running on the given stream.\ninline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {\n     // allocate pinned buffer for faster, async transfer\n     // from the docs (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDART__HIGHLEVEL_ge439496de696b166ba457dab5dd4f356.html)\n     // WC memory is a good option for buffers that will be written by the CPU and read by the device via mapped pinned memory or host->device transfers.\n    char* buffer_space;\n    cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size, cudaHostAllocWriteCombined));\n    // split allocation in two\n    void* read_buffer = buffer_space;\n    void* write_buffer = buffer_space + buffer_size;\n\n    // prime the 
read buffer;\n    char* gpu_write_ptr = (char*)dest;\n    size_t copy_amount = std::min(buffer_size, num_bytes);\n    freadCheck(read_buffer, 1, copy_amount, src);\n\n    size_t rest_bytes = num_bytes - copy_amount;\n    size_t write_buffer_size = copy_amount;\n    std::swap(read_buffer, write_buffer);\n\n    // now the main loop; as long as there are bytes left\n    while(rest_bytes > 0) {\n        // initiate next read\n        copy_amount = std::min(buffer_size, rest_bytes);\n        cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));\n        gpu_write_ptr += write_buffer_size;\n        // while this is going on, read from disk\n        freadCheck(read_buffer, 1, copy_amount, src);\n        cudaCheck(cudaStreamSynchronize(stream));     // wait for both buffers to be ready.\n\n        std::swap(read_buffer, write_buffer);\n        rest_bytes -= copy_amount;\n        write_buffer_size = copy_amount;\n    }\n\n    // copy the last remaining write buffer to gpu\n    cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));\n    cudaCheck(cudaStreamSynchronize(stream));\n    cudaCheck(cudaFreeHost(buffer_space));\n}\n\n#endif // CUDA_COMMON_H"
  },
  {
    "path": "llmc/cuda_utils.cuh",
    "content": "// Utilities for use in __device__ code\n\n#ifndef CUDA_UTILS_CUH\n#define CUDA_UTILS_CUH\n\n#include \"cuda_common.h\"\n\n// ----------------------------------------------------------------------------\n// Packed128 data structure that forces the compiler to use 128-bit loads/stores\n// in GPUs that support (the LDG.128 and STS.128 instructions)\n// This is a bit similar to the use of float4 in the case of 32-bit floats, but\n// supports arbitrary precision.\n\ntemplate<class ElementType>\nstruct alignas(16) Packed128 {\n    Packed128() = default;\n    __device__ explicit Packed128(int4 bits) {\n        static_assert(sizeof(bits) == sizeof(payload), \"Size mismatch.\");\n        memcpy(&payload, &bits, sizeof(bits));\n    }\n\n    __device__  static Packed128 constant(ElementType value) {\n        Packed128 result;\n        for(int k = 0; k < size; ++k) {\n            result.payload[k] = value;\n        }\n        return result;\n    }\n    __device__ static Packed128 zeros() {\n        return constant(0.f);\n    }\n    __device__ static Packed128 ones() {\n        return constant(1.f);\n    }\n\n    __device__ ElementType& operator[](int index) {\n        return payload[index];\n    }\n    __device__ const ElementType& operator[](int index) const {\n        return payload[index];\n    }\n    __device__ int4 get_bits() const {\n        int4 bits;\n        static_assert(sizeof(bits) == sizeof(payload), \"Size mismatch.\");\n        memcpy(&bits, &payload, sizeof(bits));\n        return bits;\n    }\n    static constexpr const size_t size = sizeof(int4) / sizeof(ElementType);\n    ElementType payload[size];\n};\n\n// load a Packed128 from an aligned memory address\ntemplate<class ElementType>\n__device__ Packed128<ElementType> load128(const ElementType* address) {\n    return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};\n}\n// load a Packed128 from an aligned memory address with streaming cache hint\ntemplate<class 
ElementType>\n__device__ Packed128<ElementType> load128cs(const ElementType* address) {\n    return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};\n}\n// store a Packed128 to an aligned memory address\ntemplate<class ElementType>\n__device__ void store128(ElementType* target, Packed128<ElementType> value) {\n    *reinterpret_cast<int4*>(target) = value.get_bits();\n}\n// store a Packed128 to an aligned memory address with streaming cache hint\ntemplate<class ElementType>\n__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {\n    __stcs(reinterpret_cast<int4*>(target), value.get_bits());\n}\n// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1\ntemplate<class ElementType>\n__device__ void store128cg(ElementType* target, Packed128<ElementType> value) {\n    __stcg(reinterpret_cast<int4*>(target), value.get_bits());\n}\n\n// short-form typedefs\ntypedef Packed128<float> f128;\ntypedef Packed128<floatX> x128;\n\n// ----------------------------------------------------------------------------\n// DType support\n\n// enumerator to identify the datatype of a tensor.\nenum class DType : uint8_t {\n    FP32, FP16, BF16\n};\n\n// Given a datatype enum, returns the underlying number of bytes\n// for a scalar of that type\nsize_t sizeof_dtype(DType type) {\n    switch (type) {\n        case DType::FP32:\n            return sizeof(float);\n        case DType::FP16:\n            return sizeof(half);\n        case DType::BF16:\n            return sizeof(nv_bfloat16);\n        default: // handle or get compiler warning\n            fprintf(stderr, \"Unknown datatype\\n\");\n            exit(EXIT_FAILURE);\n    }\n}\n\nDType dtype_of(float* f) { return DType::FP32; }\nDType dtype_of(nv_bfloat16 * f) { return DType::BF16; }\nDType dtype_of(half * f) { return DType::FP16; }\n\n\n\n// ----------------------------------------------------------------------------\n// Copy, cast functions\n\n// device 
functions and the kernel to cast data between types\ntemplate<typename Td, typename Ts>\n__device__ Td cast_value(Ts val);\n\ntemplate<>\n__device__ float cast_value<float, float>(float val) {\n    return val;\n}\n\ntemplate<>\n__device__ float cast_value<float, half>(half val) {\n    return __half2float(val);\n}\n\ntemplate<>\n__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {\n    return __bfloat162float(val);\n}\n\ntemplate<typename Td, typename Ts>\n__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    // need to try grid stride looping for more perf later\n    if (idx < n) {\n        dst[idx + stride_dst * blockIdx.y] = cast_value<Td, Ts>(src[idx + stride_src * blockIdx.y]);\n    }\n}\n\n// ----------------------------------------------------------------------------\n// Warp/Block communication primitives\n\n// warp-level reduction for summing values\n__device__ inline float warpReduceSum(float val) {\n    for (int offset = 16; offset > 0; offset /= 2) {\n        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);\n    }\n    return val;\n}\n// warp-level reduction for finding the maximum value\n__device__ inline float warpReduceMax(float val) {\n    for (int offset = 16; offset > 0; offset /= 2) {\n        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFF, val, offset));\n    }\n    return val;\n}\n// requires all 32 threads in the warp to be active, but should work for any block size\n// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes\n// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end\n// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1\nusing reduction_func_t = float (*) (float);\ntemplate<reduction_func_t warp_reduction>\n__device__ inline float blockReduce(float val, bool 
final_sync=false, float out_of_bounds=0.0f) {\n    // two reductions of up to 1024 threads:\n    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)\n    __shared__ float shared_val[WARP_SIZE];\n    const int lane_id = threadIdx.x % WARP_SIZE;\n    const int warp_id = threadIdx.x / WARP_SIZE;\n    const int num_warps = blockDim.x / WARP_SIZE;\n\n    float warp_val = warp_reduction(val);\n    if (lane_id == 0) { shared_val[warp_id] = warp_val; }\n    __syncthreads();\n    warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;\n    float block_val = warp_reduction(warp_val);\n\n    if (final_sync) {\n        __syncthreads(); // only needed in loops when effectively reusing shared memory etc.\n    }\n    return block_val;\n}\n\n// Performs a _deterministic_ sum reduction. determinism is achieved by requiring that only\n// a single block be used.\ntemplate<class Float>\n__global__ void global_sum_single_block_kernel(float* result, const Float* values, size_t count) {\n    assert(gridDim.x == 1);     // only a single block!\n    float thread_sum = 0;\n    for(size_t index = threadIdx.x; index < count; index += blockDim.x) {\n        thread_sum += (float)values[index];\n    }\n\n    float reduction = blockReduce<warpReduceSum>(thread_sum, true);\n    if(threadIdx.x == 0) {\n        *result = reduction;\n    }\n}\n\ntemplate<class Float>\nvoid global_sum_deterministic(float* result, const Float* values, int count, cudaStream_t stream) {\n    global_sum_single_block_kernel<<<1, 1024, 0, stream>>>(result, values, count);\n    cudaCheck(cudaGetLastError());\n}\n\n// ----------------------------------------------------------------------------\n// memory management\n\n// allocate memory, preferably on the device\n// returns a status code. 
0 = OK, 1 = fell back to managed memory\nint cudaMallocConditionallyManaged(void** out, size_t bytes, const char *file, int line) {\n    // try to allocate\n    cudaError_t err = cudaMalloc(out, bytes);\n    if(err == cudaErrorMemoryAllocation) {\n        // if we OOM, fallback to a managed allocation. slower but at least won't crash.\n        cudaGetLastError(); // reset the error before the next API call\n        cudaCheck_(cudaMallocManaged(out, bytes), file, line);\n        cudaCheck_(cudaMemAdvise(*out, bytes, cudaMemAdviseSetPreferredLocation, cudaCpuDeviceId), file, line);\n        return 1;\n    } else {\n        cudaCheck_(err, file, line);\n        return 0;\n    }\n}\n\n#define cudaMallocConditionallyManaged(out, bytes)\\\n(cudaMallocConditionallyManaged((void**)out, bytes, __FILE__, __LINE__))\n\n// ----------------------------------------------------------------------------\n// Random Number Generation used in Stochastic Rounding\n\n// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5)\n// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU\n// todo - possibly overkill and we don't need such high quality random numbers? 
(tbd)\n// http://eiserloh.net/noise/SquirrelNoise5.hpp\n__device__ __host__ constexpr unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed)\n{\n    constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f;\t// 11010010101010000000101000111111\n    constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197;\t// 10101000100001001111000110010111\n    constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011\n    constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB;\t// 10110111100111110011101010111011\n    constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5;\t// 00011011010101101100010011110101\n    unsigned int mangledBits = positionX;\n    mangledBits *= SQ5_BIT_NOISE1;\n    mangledBits += seed;\n    mangledBits ^= (mangledBits >> 9);\n    mangledBits += SQ5_BIT_NOISE2;\n    mangledBits ^= (mangledBits >> 11);\n    mangledBits *= SQ5_BIT_NOISE3;\n    mangledBits ^= (mangledBits >> 13);\n    mangledBits += SQ5_BIT_NOISE4;\n    mangledBits ^= (mangledBits >> 15);\n    mangledBits *= SQ5_BIT_NOISE5;\n    mangledBits ^= (mangledBits >> 17);\n    return mangledBits;\n}\n__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed)\n{\n    constexpr unsigned int PRIME_NUMBER = 198491317u; // Large prime number with non-boring bits\n    unsigned int x = static_cast<unsigned int>(indexX);\n    unsigned int y = static_cast<unsigned int>(indexY);\n\n    return SquirrelNoise5(x + (PRIME_NUMBER * y), seed);\n}\n\n// stochastic rounding built on top of Squirrel Noise above (with seed updated per step via xorshift)\n__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) {\n    // todo - is this stochastic rounding *too good*? 
can we cut any corners?\n    // makes sure each thread gets a different random number\n    unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed);\n    unsigned int threshold = random & 0xFFFF;\n    unsigned int float_bits = __float_as_uint(in);\n    unsigned int rounded_bits = float_bits & 0x0000FFFF;\n    float_bits = (rounded_bits > threshold) ? (float_bits | 0xFFFF) : (float_bits  & ~0xFFFF);\n    *out = __float2bfloat16_rn(__uint_as_float(float_bits));\n}\n__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) {\n    *out = (float)in; // todo - implement this...\n}\n__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) {\n    *out = in; // dummy function for when floatX is float (FP32 mode)\n}\n\n#endif"
  },
  {
    "path": "llmc/cudnn_att.cpp",
    "content": "// all cudnn-related functions are in this file, so that they don't need to be recompiled everytime\n// we change some unrelated piece of the code.\n// TODO this currently duplicates some of the utilities from the main file\n\n#define NOMINMAX\n#include <unistd.h>\n#include \"cudnn_att.h\"\n#include <cudnn_frontend.h>\n\nnamespace fe = cudnn_frontend;\n\n// Specific configurations based on the enabled precision\n#if defined(ENABLE_FP32)\nstatic_assert(false, \"cuDNN is not supported in FP32 mode.\")\n// use fp16 (note: this may require gradient scaler, currently not implemented!)\n#elif defined(ENABLE_FP16)\n#define CUDNN_16BIT fe::DataType_t::HALF\n#else // Default to bfloat16\n#define CUDNN_16BIT fe::DataType_t::BFLOAT16\n#endif\n\nstatic cudnnHandle_t cudnn_handle;\nstatic size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)\nstatic void* cudnn_workspace = NULL;\n\nstatic void cuDNNCheck(cudnnStatus_t error, const char *file, int line) {\n    if (error != CUDNN_STATUS_SUCCESS) {\n        printf(\"[CUDNN ERROR] at file %s:%d:\\n%s\\n\", file, line, cudnnGetErrorString(error));\n        exit(EXIT_FAILURE);\n    }\n};\n#define cuDNNCheck(err) (cuDNNCheck(err, __FILE__, __LINE__))\n\nstatic void checkCudnnFE(const fe::error_object& e, const char *file, int line) {\n    if(!e.is_good()) {\n        printf(\"[CUDNN ERROR] at file %s:%d:\\n%s\\n\", file, line, e.err_msg.c_str());\n        exit(EXIT_FAILURE);\n    }\n}\n#define checkCudnnFE(err) checkCudnnFE(err, __FILE__, __LINE__)\n\nenum UIDs {\n    Q_UID,\n    K_UID,\n    V_UID,\n    Attn_scale_UID,\n    O_UID,\n    Stats_UID,\n    dO_UID,\n    dQ_UID,\n    dK_UID,\n    dV_UID\n};\n\n// Need a cache because graph->build_operation_graph() is slow but everything else seems fast\nusing cache_type_fwd = std::map<std::tuple<int,int,int,int, int>, std::shared_ptr<fe::graph::Graph>>;\nusing cache_type_bwd = std::map<std::tuple<int,int,int,int>, 
std::shared_ptr<fe::graph::Graph>>;\n\n// Loosely based on cuDNN frontend samples functions and massively simplified\nauto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_only) {\n\n    static cache_type_fwd user_maintained_cache_fwd;\n\n    auto key = std::make_tuple(B, H, T, HS, is_inference_only);\n\n    auto it = user_maintained_cache_fwd.find(key);\n    if (it != user_maintained_cache_fwd.end()) {\n        return it->second;\n    }\n\n    auto graph = std::make_shared<fe::graph::Graph>();\n    graph->set_io_data_type(CUDNN_16BIT)\n          .set_intermediate_data_type(fe::DataType_t::FLOAT)\n          .set_compute_data_type(fe::DataType_t::FLOAT);\n\n    // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute\n    auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name(\"Q\")\n                               .set_dim({B, H, T, HS})\n                               .set_uid(Q_UID)\n                               .set_stride({3 * H * HS * T,  HS, 3 * H * HS, 1}));\n    auto K = graph->tensor(fe::graph::Tensor_attributes().set_name(\"K\")\n                               .set_dim({B, H, T, HS})\n                               .set_uid(K_UID)\n                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));\n    auto V = graph->tensor(fe::graph::Tensor_attributes().set_name(\"V\")\n                               .set_dim({B, H, T, HS})\n                               .set_uid(V_UID)\n                               .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));\n    auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name(\"attn_scale\")\n                               .set_dim({1, 1, 1, 1})\n                               .set_stride({1, 1, 1, 1})\n                               .set_uid(Attn_scale_UID)\n                               .set_is_pass_by_value(true)\n                               .set_data_type(fe::DataType_t::FLOAT));\n\n    auto sdpa_options = 
fe::graph::SDPA_attributes().set_name(\"flash_attention\");\n    sdpa_options.set_is_inference(is_inference_only);\n    sdpa_options.set_attn_scale(attn_scale);\n    sdpa_options.set_causal_mask(true);\n\n    // Create the graph operation and get the output tensors back\n    auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);\n\n    // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32\n    O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(O_UID);\n\n    assert(stats == nullptr || is_inference_only == false);\n    if (is_inference_only == false) {\n        stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)\n                               .set_dim({B, H, T, 1})\n                               .set_stride({H * T, T, 1, 1})\n                               .set_uid(Stats_UID);\n    }\n\n    checkCudnnFE(graph->validate());\n\n    // Build the operation graph and execution part (this is the VERY SLOW PART)\n    checkCudnnFE(graph->build_operation_graph(cudnn_handle));\n    auto plans = graph->create_execution_plans({fe::HeurMode_t::A});\n    checkCudnnFE(graph->check_support(cudnn_handle));\n    checkCudnnFE(graph->build_plans(cudnn_handle));\n    // Reallocate the workspace if the required size is greater than the current workspace\n    // In H100 this may be around 16B\n    if (graph->get_workspace_size() > cudnn_workspace_size) {\n        if (cudnn_workspace_size > 0) {\n            cudaCheck(cudaFree(cudnn_workspace));\n        }\n        cudnn_workspace_size = graph->get_workspace_size();\n        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));\n    }\n\n    user_maintained_cache_fwd.insert({key, graph});\n\n    return graph;\n}\n\nauto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {\n    static cache_type_bwd user_maintained_cache_bwd;\n\n    auto key = std::make_tuple(B, NH, T, HS);\n\n    auto it = user_maintained_cache_bwd.find(key);\n    if (it != 
user_maintained_cache_bwd.end()) {\n        return it->second;\n    }\n\n    auto graph = std::make_shared<fe::graph::Graph>();\n    graph->set_io_data_type(CUDNN_16BIT)\n          .set_intermediate_data_type(fe::DataType_t::FLOAT)\n          .set_compute_data_type(fe::DataType_t::FLOAT);\n\n    // (B, T, 3, NH, HS)\n    // must come from inp (which means we also need to convert THAT to FP16)\n    auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name(\"Q\")\n                            .set_dim({B, NH, T, HS})\n                            .set_uid(Q_UID)\n                            .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));\n    auto K = graph->tensor(fe::graph::Tensor_attributes().set_name(\"K\")\n                            .set_dim({B, NH, T, HS})\n                            .set_uid(K_UID)\n                            .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));\n    auto V = graph->tensor(fe::graph::Tensor_attributes().set_name(\"V\")\n                            .set_dim({B, NH, T, HS})\n                            .set_uid(V_UID)\n                            .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));\n    auto O = graph->tensor(fe::graph::Tensor_attributes().set_name(\"O\")\n                            .set_dim({B, NH, T, HS})\n                            .set_uid(O_UID)\n                            .set_stride({NH * HS * T, HS, NH * HS, 1}));\n    auto dO = graph->tensor(fe::graph::Tensor_attributes().set_name(\"dO\")\n                            .set_dim({B, NH, T, HS})\n                            .set_uid(dO_UID)\n                            .set_stride({NH * HS * T, HS, NH * HS, 1}));\n\n    auto stats = graph->tensor(fe::graph::Tensor_attributes().set_name(\"stats\")\n                            .set_dim({B, NH, T, 1})\n                            .set_uid(Stats_UID)\n                            .set_stride({NH * T, T, 1, 1})\n                            .set_data_type(fe::DataType_t::FLOAT));\n    auto attn_scale 
= graph->tensor(fe::graph::Tensor_attributes().set_name(\"attn_scale\")\n                            .set_dim({1, 1, 1, 1})\n                            .set_stride({1, 1, 1, 1})\n                            .set_is_pass_by_value(true)\n                            .set_uid(Attn_scale_UID)\n                            .set_data_type(fe::DataType_t::FLOAT));\n    auto sdpa_backward_options = fe::graph::SDPA_backward_attributes().set_name(\"flash_attention_backward\")\n#if CUDNN_FRONTEND_MAJOR_VERSION > 1 || CUDNN_FRONTEND_MINOR_VERSION >= 5\n                            .set_deterministic_algorithm(true) // 1.5+ needs this for determinism\n#endif\n                            .set_causal_mask(true)\n                            .set_attn_scale(attn_scale);\n\n    // Create the graph operation and get the output tensors back\n    auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);\n\n    dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dQ_UID);\n    dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dK_UID);\n    dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dV_UID);\n\n    checkCudnnFE(graph->validate());\n\n    // Build the operation graph and execution part (this is the VERY SLOW PART)\n    checkCudnnFE(graph->build_operation_graph(cudnn_handle));\n    auto plans = graph->create_execution_plans({fe::HeurMode_t::A});\n    checkCudnnFE(graph->check_support(cudnn_handle));\n    checkCudnnFE(graph->build_plans(cudnn_handle));\n\n    // Reallocate the workspace if the required size is greater than the current workspace\n    // By default, cuDNN uses up to 256MiB of workspace, so we don't want to just allocate the maximum\n    if (graph->get_workspace_size() > cudnn_workspace_size) {\n        if (cudnn_workspace_size > 0) {\n            cudaCheck(cudaFree(cudnn_workspace));\n    
    }\n        cudnn_workspace_size = graph->get_workspace_size();\n        cudaCheck(cudaMalloc(&cudnn_workspace, cudnn_workspace_size));\n    }\n\n    user_maintained_cache_bwd.insert({key, graph});\n    return graph;\n}\n\nvoid attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)\n                             float* stats, // output for backward pass: (B, NH, T)\n                             floatX* inp,  // input: (B, T, 3, NH, HS) QKV\n                             int B, int T, int NH, int C, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    int HS = C / NH; // number of features per head\n    bool is_inference_only = (stats == nullptr);\n\n    cuDNNCheck(cudnnSetStream(cudnn_handle, stream));\n\n    // Get graph and tensors from cache (or generate it on first use)\n    auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);\n\n    // Prepare all the tensor pointers for executing the graph\n    void* devPtrQ = inp;\n    void* devPtrK = (inp + C);\n    void* devPtrV = (inp + 2 * C);\n    float attn_scale_cpu = 1.0 / sqrtf(HS);\n    void* devPtrO = out;\n\n    // Build variant pack\n    std::unordered_map<int64_t , void*> variant_pack = {\n        {Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {Attn_scale_UID, &attn_scale_cpu}, {O_UID, devPtrO}};\n\n    // Add the stats tensor unless we are only doing inference (only needed for backward pass)\n    if (is_inference_only == false) {\n        variant_pack[Stats_UID] = stats;\n    }\n\n    // Execute graph\n    checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));\n    cudaCheck(cudaGetLastError());\n}\n\nvoid attention_backward_cudnn(floatX* dqkvr,                                       // output\n                              floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs\n                              int B, int T, int NH, int C, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    int HS = C / NH; // number of features per head\n\n    // 
Get graph and tensors from cache (or generate it on first use)\n    auto graph = lookup_cache_or_build_graph_bwd(B, NH, T, HS);\n\n    // Prepare all the tensor pointers for executing the graph\n    void* devPtrQ = qkvr;\n    void* devPtrK = (qkvr + NH * HS);\n    void* devPtrV = (qkvr + 2 * NH * HS);\n    void* devPtrO = o;\n    void* devPtrdO = dout;\n    void* devPtrStats = stats;\n    float attn_scale_cpu = 1.0 / sqrtf(HS);\n\n    void* devPtrdQ = dqkvr;\n    void* devPtrdK = (dqkvr + NH * HS);\n    void* devPtrdV = (dqkvr + 2 * NH * HS);\n\n    // Build variant pack that links each tensor to its data pointer\n    std::unordered_map<int64_t, void*> variant_pack = {\n        {Q_UID, devPtrQ}, {K_UID, devPtrK}, {V_UID, devPtrV}, {O_UID, devPtrO}, {dO_UID, devPtrdO}, {Stats_UID, devPtrStats},\n        {dQ_UID, devPtrdQ}, {dK_UID, devPtrdK}, {dV_UID, devPtrdV},\n        {Attn_scale_UID, &attn_scale_cpu}};\n\n    // Execute graph\n    cuDNNCheck(cudnnSetStream(cudnn_handle, stream));\n    checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));\n    cudaCheck(cudaGetLastError());\n}\n\nvoid create_cudnn() {\n    cuDNNCheck(cudnnCreate(&cudnn_handle));\n}\n\nvoid destroy_cudnn() {\n    if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); }\n    cuDNNCheck(cudnnDestroy(cudnn_handle));\n}"
  },
  {
    "path": "llmc/cudnn_att.h",
    "content": "/*\ncuDNN (flash) attention\n*/\n#ifndef CUDNN_ATT_H\n#define CUDNN_ATT_H\n\n#include \"cuda_common.h\"\n\n// forward declarations of functions defined in cudnn_att.cpp\nvoid create_cudnn();\nvoid destroy_cudnn();\nvoid attention_forward_cudnn(floatX* out,  // output: (B, T, NH, HS)\n                             float* stats, // output for backward pass: (B, NH, T)\n                             floatX* inp,  // input: (B, T, 3, NH, HS) QKV\n                             int B, int T, int NH, int C, cudaStream_t stream);\n\nvoid attention_backward_cudnn(floatX* dqkvr,                                       // output\n                              floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs\n                              int B, int T, int NH, int C, cudaStream_t stream);\n\n#endif // CUDNN_ATT_H"
  },
  {
    "path": "llmc/dataloader.h",
    "content": "/*\nImplements:\n- DataLoader for model training. Reads and serves data shards.\n- EvalLoader for multiple-choice evaluation datasets, e.g. HellaSwag.\n*/\n#ifndef DATALOADER_H\n#define DATALOADER_H\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <stddef.h>\n#include <stdint.h>\n#include <assert.h>\n#include <string.h>\n// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck\n// defines: mallocCheck\n#include \"utils.h\"\n#include \"rand.h\"\n\n// ----------------------------------------------------------------------------\n// implementation of glob for Windows is in dev/unistd.h\n#ifndef _WIN32\n#include <glob.h>\n#endif\n// ----------------------------------------------------------------------------\n// Distributed Data Loader\n#define HEADER_SIZE 256\n\ntypedef struct {\n    // variables related to distributed training\n    // each process/worker has to access different parts of the data\n    int process_rank;\n    int num_processes;\n    // batch and token information\n    size_t B;\n    size_t T;\n    size_t num_tokens; // total number of tokens\n    size_t shard_num_samples;  // total number of samples in the current shard per process\n    // shards and current position\n    glob_t glob_result; // stores the result of glob, for all shards we want to iterate\n    size_t current_shard_idx; // the current shard we are reading from\n    size_t current_sample_idx; // the current sample we are reading from\n    // file handle\n    FILE* tokens_file;\n    // data buffers\n    uint16_t* buffer; // we fread data from file into this buffer\n    int* inputs;  // input tokens into transformer\n    int* targets; // target tokens for the transformer\n    // random shuffle related variables\n    mt19937_state shuffle_rng;\n    int should_shuffle;\n    int* shard_indices;\n    int* intra_shard_indices;\n    // sizes in bytes\n    size_t total_batch_size_bytes;  // total across all processes\n    size_t local_batch_offset_bytes;  // inner-sample 
offset for this process\n    size_t header_bytes;  // header size in bytes\n    int64_t file_size_bytes;\n} DataLoader;\n\nint64_t dataloader_load_shard_(DataLoader *loader, int shard_index) {\n    if (loader->should_shuffle) {\n        shard_index = loader->shard_indices[shard_index];\n    }\n    // use the first glob match as the filename for now\n    const char* filename = loader->glob_result.gl_pathv[shard_index];\n    // open the input file for reading. also only a single file can be opened at a time\n    if (loader->tokens_file != NULL) {\n        fcloseCheck(loader->tokens_file);\n    }\n    loader->tokens_file = fopenCheck(filename, \"rb\");\n    // validate the header\n    int header[HEADER_SIZE];\n    freadCheck(header, sizeof(int), HEADER_SIZE, loader->tokens_file);\n    if (header[0] != 20240520) {\n        printf(\"Bad magic in the data file\\n\");\n        printf(\"---> HINT: Are you passing in a correct file?\\n\");\n        printf(\"---> HINT: The data encoding may have changed, re-run data prepro or refer again to README.\\n\");\n        exit(EXIT_FAILURE);\n    }\n    if (header[1] != 1) { printf(\"Bad version in data file\\n\"); exit(EXIT_FAILURE); }\n    int64_t ntok = header[2]; // number of tokens in the file\n    assert(ntok > 0); // we expect some tokens in the file. this should never trip, right?\n    // determine the file size and make sure it is consistent with the number of tokens\n    fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file\n    loader->file_size_bytes = ftell(loader->tokens_file); // read the offset, i.e. 
file size\n    fseekCheck(loader->tokens_file, 0, SEEK_SET); // seek back to the beginning\n    // we expect ntok in the file to be consistent with filesize, assert that is the case\n    int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t);\n    if (loader->file_size_bytes != expected_file_size) {\n        printf(\"Error: file size is not as expected\\n\");\n        exit(EXIT_FAILURE);\n    }\n    // -1 uint16_t due to us taking B*T+1 tokens but moving by B*T tokens\n    loader->shard_num_samples = (ntok * sizeof(uint16_t) - sizeof(uint16_t)) / loader->total_batch_size_bytes;\n    return ntok;\n}\n\nvoid prepare_intra_shard_indices_(DataLoader *loader) {\n    // shuffle the examples inside the shards\n    if (loader->intra_shard_indices != NULL) {\n        // in case shards have different number of samples / sizes\n        free(loader->intra_shard_indices);\n    }\n    loader->intra_shard_indices = (int*)mallocCheck(loader->shard_num_samples * sizeof(int));\n    init_identity_permutation(loader->intra_shard_indices, (int) loader->shard_num_samples);\n    random_permutation(loader->intra_shard_indices, (int) loader->shard_num_samples, &loader->shuffle_rng);\n}\n\nvoid dataloader_reset(DataLoader *loader) {\n    loader->current_shard_idx = 0;\n    loader->current_sample_idx = 0;\n\n    if (loader->should_shuffle) {  // shuffle the shards\n        random_permutation(loader->shard_indices, (int) loader->glob_result.gl_pathc, &loader->shuffle_rng);\n    }\n\n    dataloader_load_shard_(loader, (int) loader->current_shard_idx);\n\n    if (loader->should_shuffle) {\n        prepare_intra_shard_indices_(loader);\n    }\n}\n\nvoid dataloader_advance_(DataLoader *loader) {\n    if (loader->current_shard_idx == loader->glob_result.gl_pathc - 1) {\n        // if we are at the last shard, we reset the loader and start a new epoch\n        dataloader_reset(loader);\n        return;\n    }\n\n    // advance the loader by loading the next data shard and 
resetting the position\n    loader->current_shard_idx = (loader->current_shard_idx + 1) % loader->glob_result.gl_pathc;\n    loader->current_sample_idx = 0;\n    dataloader_load_shard_(loader, (int) loader->current_shard_idx);\n\n    if (loader->should_shuffle) {\n        prepare_intra_shard_indices_(loader);\n    }\n}\n\nvoid dataloader_init(DataLoader *loader,\n                     const char* filename_pattern,\n                     size_t B,\n                     size_t T,\n                     int process_rank,\n                     int num_processes,\n                     int should_shuffle) {\n    loader->process_rank = process_rank;\n    loader->num_processes = num_processes;\n    loader->B = B;\n    loader->T = T;\n    loader->tokens_file = NULL;\n    loader->should_shuffle = should_shuffle;\n    loader->header_bytes = HEADER_SIZE * sizeof(int);\n    loader->total_batch_size_bytes = ((loader->num_processes * (loader->B * loader->T)) * sizeof(uint16_t));\n    loader->local_batch_offset_bytes = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);\n\n    // glob to get the list of files matching the pattern, these are our data shards\n    int glob_status = glob(filename_pattern, 0, NULL, &loader->glob_result);\n    if (glob_status != 0) {\n        printf(\"Error: failed to glob pattern: %s\\n\", filename_pattern);\n        exit(EXIT_FAILURE);\n    }\n    if (loader->glob_result.gl_pathc == 0) {\n        printf(\"Error: no files found matching the pattern: %s\\n\", filename_pattern);\n        exit(EXIT_FAILURE);\n    }\n\n    if (should_shuffle) {\n        mt19937_state shuffle_rng;\n        manual_seed(&shuffle_rng, 42 + process_rank);\n        loader->shuffle_rng = shuffle_rng;\n        loader->shard_indices = (int*)mallocCheck(loader->glob_result.gl_pathc * sizeof(int));\n        init_identity_permutation(loader->shard_indices, (int) loader->glob_result.gl_pathc);\n        loader->intra_shard_indices = NULL;  // dynamically allocated allowing 
different shard sizes\n    }\n\n    // inspect and validate all shards so we don't get any runtime errors later\n    // if too slow / too many shards, may wish to revisit later\n    int64_t ntok_total = 0;\n    for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) {\n        int64_t shard_ntok = dataloader_load_shard_(loader, shard_index);\n        // we need at least one batch/shard, the way things are written right now.\n        // can be relaxed a lot later.\n        assert(shard_ntok >= (int64_t) (num_processes * B * T + 1));\n        ntok_total += shard_ntok;\n    }\n    // debugging prints\n    // printf(\"DataLoader: filename_pattern: %s\\n\", filename_pattern);\n    // printf(\"DataLoader: Found %ld tokens across %zu shards\\n\", ntok_total, loader->glob_result.gl_pathc);\n\n    // allocate all the space we'll need\n    loader->buffer = (uint16_t*)mallocCheck((B * T + 1) * sizeof(uint16_t));\n    loader->inputs = (int*)mallocCheck(B * T * sizeof(int));\n    loader->targets = (int*)mallocCheck(B * T * sizeof(int));\n    loader->num_tokens = ntok_total;\n\n    // reset the loader, to initialize it\n    dataloader_reset(loader);\n}\n\nvoid dataloader_load_batch(DataLoader* loader) {\n    assert(!loader->should_shuffle || (loader->should_shuffle && loader->intra_shard_indices != NULL));\n    assert(loader->current_sample_idx < loader->shard_num_samples);\n    size_t idx = loader->should_shuffle ? 
loader->intra_shard_indices[loader->current_sample_idx] : loader->current_sample_idx;\n    size_t global_batch_offset_bytes = idx * loader->total_batch_size_bytes;\n    int64_t current_offset = loader->header_bytes + global_batch_offset_bytes + loader->local_batch_offset_bytes;\n\n    size_t B = loader->B;\n    size_t T = loader->T;\n    // read B*T+1 uint16_t tokens from the file into buffer\n    fseekCheck(loader->tokens_file, (int) current_offset, SEEK_SET);\n    freadCheck(loader->buffer, sizeof(uint16_t), B*T+1, loader->tokens_file);\n    // decode the buffer into inputs and targets (cast to int)\n    for (int i = 0; i < B*T; i++) {\n        loader->inputs[i] = (int)loader->buffer[i];\n        loader->targets[i] = (int)loader->buffer[i+1];\n    }\n}\n\nvoid dataloader_next_batch(DataLoader *loader) {\n    // if the next batch would go past the end of the file, advance the loader\n    if (loader->current_sample_idx >= loader->shard_num_samples) {\n        dataloader_advance_(loader);\n    }\n    dataloader_load_batch(loader);\n    loader->current_sample_idx += 1;\n}\n\n\nvoid dataloader_resume(DataLoader *loader, size_t current_shard_idx, size_t current_sample_idx) {\n    // used during model resumption (-y 1 flag)\n    loader->current_shard_idx = current_shard_idx;\n    loader->current_sample_idx = current_sample_idx;\n    dataloader_load_shard_(loader, (int) loader->current_shard_idx);\n}\n\nvoid dataloader_free(DataLoader *loader) {\n    free(loader->buffer);\n    free(loader->inputs);\n    free(loader->targets);\n    if (loader->should_shuffle) {\n        free(loader->shard_indices);\n        free(loader->intra_shard_indices);\n    }\n    fcloseCheck(loader->tokens_file);\n    globfree(&loader->glob_result);\n}\n\n// ----------------------------------------------------------------------------\n// Distributed Eval Loader\n// Many evals (like HellaSwag and MMLU) are multiple-choice\n// where there are 4 possible continuations and a label for the correct 
one\n// We want to load and serve these style of evals\n/*\nCopy pasting the section on the eval datafile format, from data_common.py:\n- First comes a header with 256 int32s\n- The examples follow, each example is a stream of uint16_t:\n    - <START_EXAMPLE> delimiter of 2**16-1, i.e. 65,535\n    - <EXAMPLE_BYTES>, bytes encoding this example, allowing efficient skip to next\n    - <EXAMPLE_INDEX>, the index of the example in the dataset\n    - <LABEL>, the index of the correct completion\n    - <NUM_COMPLETIONS>, indicating the number of completions (usually 4)\n    - <NUM><CONTEXT_TOKENS>, where <NUM> is the number of tokens in the context\n    - <NUM><COMPLETION_TOKENS>, repeated NUM_COMPLETIONS times\n*/\n\n// for now, could relax later\n#define ASSUMED_NUM_COMPLETIONS 4\n// helper macro for ceildiv\n#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))\n\ntypedef struct {\n    // variables related to distributed training\n    // each process/worker has to access different parts of the data\n    int process_rank;\n    int num_processes;\n    // hyperparameters. use size_t to prevent overflow\n    size_t B; // (micro) batch size dimension of the tensor that feeds into the model\n    size_t T; // maximum context length of the model\n    // input handling and its state\n    FILE* eval_file;\n    uint16_t* buffer; // we fread data from file into this buffer\n    // public variables that could be accessed from outside\n    int num_examples; // in total across all processes\n    int num_batches; // to process the entire dataset across all processes\n    int start_example_index; // the assignment of work for this process, start\n    int end_example_index; // and end. 
start is inclusive, end is exclusive\n    int current_example_index; // the next example we would read\n    int* inputs;  // input tokens into transformer\n    int* targets; // target tokens for the transformer\n    char* mask; // mask=1 at all completion token locations\n    int* label; // the correct completion labels\n    int num_completions; // number of completions for this example\n} EvalLoader;\n\nvoid evalloader_reset(EvalLoader *loader) {\n    // we have to be careful that each process starts at the correct offset.\n    // For example if there are N examples in the file and 4 processes,\n    // then process 0 should start at 0, process 1 at N/4, process 2 at N/2, etc.\n    // determine how much work there is for all processes\n    int examples_per_process = CEIL_DIV(loader->num_examples, loader->num_processes);\n    int can_fit_examples = (int) (loader->B / ASSUMED_NUM_COMPLETIONS);\n    if (can_fit_examples == 0) {\n        // this could be fixed in the future, but for now keeping it simple and throw error when B too low\n        printf(\"HellaSwag EvalLoader: batch size %zu is < %d\\n\", loader->B, ASSUMED_NUM_COMPLETIONS);\n        printf(\"---> HINT: Disable HellaSwag eval with -h 0, or increase batch size with -b\\n\");\n        exit(EXIT_FAILURE);\n    }\n    loader->num_batches = CEIL_DIV(examples_per_process, can_fit_examples);\n    // determine the start and end example indices for this process\n    loader->start_example_index = examples_per_process * loader->process_rank;\n    loader->end_example_index = examples_per_process * (loader->process_rank + 1);\n    // crop the end example index to the total number of examples\n    if (loader->end_example_index > loader->num_examples) {\n        loader->end_example_index = loader->num_examples;\n    }\n    // now seek through the file to the start of that example\n    // utilize <EXAMPLE_BYTES> for efficiency\n    int64_t header_bytes = HEADER_SIZE * sizeof(int);\n    fseekCheck(loader->eval_file, (int) 
header_bytes, SEEK_SET);\n    for (int i = 0; i < loader->start_example_index; i++) {\n        uint16_t example_header[3];\n        // read 3 uint16_t values: <START_EXAMPLE>, <EXAMPLE_BYTES>, <EXAMPLE_INDEX>\n        freadCheck(&example_header[0], sizeof(uint16_t), 3, loader->eval_file);\n        // validate the <START_EXAMPLE> delimiter\n        assert(example_header[0] == 65535); // <START_EXAMPLE> delimiter\n        // validate the <EXAMPLE_INDEX>\n        assert(example_header[2] == i); // <EXAMPLE_INDEX> should match the loop index\n        // skip to the next example, keeping in mind that we already read the header\n        size_t remaining_bytes = example_header[1] - sizeof(uint16_t) * 3;\n        assert(remaining_bytes > 0); // we expect some bytes in the example\n        fseekCheck(loader->eval_file, (int) remaining_bytes, SEEK_CUR);\n    }\n    // now we are at the start of the example we want to start at, pointing at <START_EXAMPLE>\n    loader->current_example_index = loader->start_example_index;\n}\n\nvoid evalloader_init(EvalLoader *loader,\n                     const char* filename,\n                     size_t B,\n                     size_t T,\n                     int process_rank,\n                     int num_processes) {\n    loader->process_rank = process_rank;\n    loader->num_processes = num_processes;\n    loader->B = B;\n    loader->T = T;\n\n    // open the file and validate the header\n    loader->eval_file = fopenCheck(filename, \"rb\");\n    // validate the header\n    int header[HEADER_SIZE];\n    freadCheck(header, sizeof(int), HEADER_SIZE, loader->eval_file);\n    if (header[0] != 20240522) { printf(\"Bad magic in eval file\\n\"); exit(EXIT_FAILURE); }\n    if (header[1] != 1) { printf(\"Bad version in data file\\n\"); exit(EXIT_FAILURE); }\n    loader->num_examples = header[2]; // number of examples in the file\n    assert(loader->num_examples >= num_processes); // avoid headaches for now\n    size_t longest_example_bytes = 
header[3]; // longest example in the file\n    // basic sensibility check we could relax later. but roughly each example\n    // contains the prompt (or \"context\") and 4 completions, all of these have to be\n    // up to T tokens, and their tokens are uint16_t (so 2 bytes/token).\n    // There's a few more things in each example but they are minor.\n    // So longest example should be roughly this. Just trying to make sure it's sensible.\n    assert(longest_example_bytes > 0 && longest_example_bytes < (1+ASSUMED_NUM_COMPLETIONS)*T*2);\n\n    // allocate all the space we'll need\n    int can_fit_examples = (int) (B / ASSUMED_NUM_COMPLETIONS);\n    loader->buffer = (uint16_t*)mallocCheck(longest_example_bytes);\n    loader->inputs = (int*)calloc(B * T, sizeof(int));\n    loader->targets = (int*)calloc(B * T, sizeof(int));\n    loader->mask = (char*)mallocCheck(B * T * sizeof(char));\n    loader->label = (int*)mallocCheck(can_fit_examples * sizeof(int));\n\n    // reset the loader, to initialize it\n    evalloader_reset(loader);\n}\n\nvoid evalloader_next_example_(EvalLoader *loader, int example_batch_index) {\n    // this function populates the inputs, targets, mask, and label fields for one example\n    // because every (B,T) tensor can fit multiple examples and we want to take advantage,\n    // we also pass in the example_batch_index to indicate which example in the batch we are loading\n    // and each example takes up ASSUMED_NUM_COMPLETIONS rows in the batch\n    size_t B = loader->B;\n    size_t T = loader->T;\n    int batch_dim_offset = example_batch_index * ASSUMED_NUM_COMPLETIONS;\n    // read the current example header\n    uint16_t example_header[3];\n    freadCheck(&example_header[0], sizeof(uint16_t), 3, loader->eval_file);\n    // validate the <START_EXAMPLE> delimiter\n    assert(example_header[0] == 65535); // <START_EXAMPLE> delimiter\n    // validate the <EXAMPLE_INDEX>\n    assert(example_header[2] == loader->current_example_index); // 
<EXAMPLE_INDEX> should match the loop index\n    assert(example_header[2] >= loader->start_example_index && example_header[2] < loader->end_example_index);\n    // read the rest of the example (we have space for 3 more uint16_t values in buffer, it's ok)\n    size_t example_bytes = example_header[1] - sizeof(uint16_t) * 3;\n    // read example_bytes into buffer. careful that this is actually in the units of bytes\n    freadCheck(loader->buffer, sizeof(char), example_bytes, loader->eval_file);\n    // process the example label\n    int label = (int)loader->buffer[0];\n    int can_fit_examples = (int) (loader->B / ASSUMED_NUM_COMPLETIONS);\n    assert(label >= 0 && label < ASSUMED_NUM_COMPLETIONS); // we expect the label to be in [0, 4) for right now\n    assert(example_batch_index >= 0 && example_batch_index < can_fit_examples);\n    loader->label[example_batch_index] = label; // store for output\n    // process the number of completions\n    int num_completions = (int)loader->buffer[1];\n    assert(num_completions == ASSUMED_NUM_COMPLETIONS); // we expect 4 completions for now\n    assert(batch_dim_offset + num_completions <= B); // we expect to fit in the batch\n    loader->num_completions = num_completions; // store for output\n    // process the context\n    // the context is shared for all completions, so we insert it into all data rows equally\n    int context_length = (int)loader->buffer[2];\n    uint16_t *context_tokens_start = &loader->buffer[3]; // where the tokens start\n    assert(context_length > 0 && context_length < T); // context is non-empty and up to T\n    for (int b = 0; b < num_completions; b++) {\n        for (int i = 0; i < context_length; i++) {\n            int boff = batch_dim_offset + b;\n            int tok_cur = (int)context_tokens_start[i];\n            loader->inputs[boff * T + i] = tok_cur;\n        }\n    }\n    // process the completions, insert them in their row, right after the (shared) context\n    uint16_t *completions_iter = 
loader->buffer + 3 + context_length;\n    for (int c = 0; c < num_completions; c++) {\n        int coff = batch_dim_offset + c;\n        int completion_length = (int)completions_iter[0];\n        uint16_t *completion_tokens_start = completions_iter + 1;\n        assert(completion_length > 0 && context_length + completion_length < T); // things fit?\n        for (int i = 0; i < completion_length; i++) {\n            int tok_cur = (int)completion_tokens_start[i];\n            // at inputs, the completions simply follow the context\n            loader->inputs[coff * T + context_length + i] = tok_cur;\n            // at targets things start to get tricky\n            // we expect the last context token to predict the first completion token\n            // and then onwards from there.\n            loader->targets[coff * T + context_length + i - 1] = tok_cur;\n            // and at these positions, we want to set mask=1, because these are the\n            // positions where we want to average the loss, in each row, to determine\n            // its overall probability of following the context.\n            loader->mask[coff * T + context_length + i - 1] = 1;\n        }\n        completions_iter += 1 + completion_length; // move to the next completion\n    }\n    // advance the current example to point to the next one we'd load\n    loader->current_example_index += 1;\n}\n\nvoid evalloader_next_batch(EvalLoader *loader) {\n    size_t B = loader->B;\n    size_t T = loader->T;\n    // init mask to zeros, no need to do it for inputs & targets, the values where the mask\n    // is set will be correctly overwritten every time.\n    memset(loader->mask, 0, B * T * sizeof(char));\n    // ok here is the problem we are solving\n    // we have a batch dimension of B, which we want to take full advantage of\n    // each example has some number of completions (usually 4)\n    // so we want to pack as many examples into rows of B as we can fit\n    int can_fit_examples = (int) (B / 
ASSUMED_NUM_COMPLETIONS); // how many examples can we fit in the batch?\n    for (int i = 0; i < can_fit_examples; i++) {\n        if (loader->current_example_index >= loader->end_example_index) {\n            break; // this process has exhausted its work, noop from here on\n        }\n        evalloader_next_example_(loader, i);\n    }\n}\n\nint evalloader_stat_losses(EvalLoader *loader, float* losses) {\n    // compute statistics of losses (B*T) resulting from a forward pass\n    // on a batch that was constructed from EvalLoader\n    // putting this functionality here because it is tightly coupled\n    // with how we construct and represent the data batches.\n    // returns the number of correct examples in this batch.\n    int correct = 0;\n    size_t B = loader->B;\n    size_t T = loader->T;\n    // iterate the examples in this batch\n    int can_fit_examples = (int) (B / ASSUMED_NUM_COMPLETIONS);\n    for (int i = 0; i < can_fit_examples; i++) {\n        float min_loss = 0.0f;\n        int min_loss_index = -1;\n        char active = 0; // is this example active or fully empty?\n        // iterate the completions in this example\n        for (int b = 0; b < ASSUMED_NUM_COMPLETIONS; b++) {\n            int boff = i * ASSUMED_NUM_COMPLETIONS + b;\n            // evaluate the quality of this completion\n            // its quality is simply the average loss over the tokens\n            float average_loss = 0.0f;\n            int count = 0;\n            for (int t = 0; t < T; t++) {\n                char mask = loader->mask[boff * T + t];\n                if (mask == 1) {\n                    active = 1;\n                    average_loss += losses[boff * T + t];\n                    count++;\n                }\n            }\n            if (count > 0) { average_loss /= count; }\n            if (b == 0 || average_loss < min_loss) {\n                min_loss = average_loss;\n                min_loss_index = b;\n            }\n        }\n        if (active && 
(min_loss_index == loader->label[i])) {\n            correct += 1;\n        }\n    }\n    return correct;\n}\n\nvoid evalloader_free(EvalLoader *loader) {\n    free(loader->buffer);\n    free(loader->inputs);\n    free(loader->targets);\n    free(loader->mask);\n    free(loader->label);\n    fcloseCheck(loader->eval_file);\n}\n\n#endif // DATALOADER_H"
  },
  {
    "path": "llmc/encoder.cuh",
    "content": "/*\nThe GPT-2 Encoder, which combines two encodings: token and position\nIn the forward pass, both encodings are added together\nIn the backward pass, the gradients flow to both, handled by different kernels\n*/\n#include <assert.h>\n#include <stdint.h>\n#include <utility>              // std::pair\n#include <vector>\n#include <algorithm>\n#include <unordered_map>\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\n__global__ void encoder_forward_kernel3(floatX* out,\n                               const int* inp, const floatX* wte, const floatX* wpe,\n                               int B, int T, int C) {\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n    int N = B * T * C;\n    if (idx >= N) { return; }\n\n    int bt = idx / C;\n    int b = bt / T;\n    int t = bt % T;\n    int c = idx % C;\n\n    int ix = inp[b * T + t];\n\n    floatX* out_btc = out + b * T * C + t * C + c;\n    const floatX* wte_ix = wte + ix * C + c;\n    const floatX* wpe_tc = wpe + t * C + c;\n\n    x128 packed_out;\n    x128 wte128 = load128cs(wte_ix);\n    x128 wpe128 = load128cs(wpe_tc);\n    for (int k = 0; k < x128::size; k++) {\n        packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);\n    }\n    store128(out_btc, packed_out);\n}\n\ntemplate <int BLOCK_SIZE=256>\n__global__ void wte_backward_kernel(floatX* dwte,\n                                    const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp,\n                                    unsigned int seed, int B, int T, int C) {\n    // In order to be deterministic, we preprocess the inputs on the cpu into \"buckets\"\n    // Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token\n    // Each thread handles x128::size channels, e.g. 
256 per warp for BF16\n    // Each block handles (BLOCK_SIZE / WARP_SIZE) elements in a single bucket in parallel\n    // If a bucket has less than 8 elements, some warps will return immediately\n    // If a bucket has more than 8 elements, we will loop over all of them\n    // The buckets are sorted on the CPU so the largest buckets start 1st\n    int bucket = blockIdx.x;\n    int warp_id = threadIdx.x / WARP_SIZE;\n    int lane_id = threadIdx.x % WARP_SIZE;\n    int c_per_warp = WARP_SIZE * x128::size;\n\n    int bucket_start_idx = bucket_info[bucket].x;\n    int bucket_size = bucket_info[bucket].y;\n    int bucket_ix = bucket_info[bucket].z;\n    int c = bucket_info[bucket].w * c_per_warp + (lane_id * x128::size);\n\n    // Each thread handles \"x128::size\" channels, so at fp8, each warp would handle 512 channels\n    // If C is not a multiple of this (e.g. 768), some buckets/c_groups cannot use the entire warp\n    if (c >= C) { return; }\n    // Exit early if this is a small bucket and this warp doesn't have any items to process\n    if (warp_id >= bucket_size) { return; }\n\n    float accum[x128::size] = {0.0f};\n    __shared__ float accum_shared[x128::size * BLOCK_SIZE];\n\n    for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) {\n        int bt = workload_indices[bucket_start_idx + item];\n\n        const floatX* dout_btc = dout + bt * C + c;\n        x128 packed_inp1 = load128cs(dout_btc);\n        for (int k = 0; k < packed_inp1.size; k++) {\n            accum[k] += (float)packed_inp1[k];\n        }\n    }\n\n    if (warp_id != 0) {\n        // we accumulate into warp 0, so only the other warps need to write to shared memory\n        for (int k = 0; k < x128::size; k++) {\n            accum_shared[threadIdx.x + k * BLOCK_SIZE] = accum[k];\n        }\n        return; // only warp 0 is needed after writing to shared memory\n    }\n\n    // Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance\n   
 floatX* dwte_ix = dwte + bucket_ix * C + c;\n    x128 packed_in_out = load128(dwte_ix);\n\n    // note: threads which have returned are considered synchronised by CUDA so no risk of deadlock\n    __syncthreads();\n\n    // Accumulate into warp 0's registers by reading the values of the other warps in shared memory\n    for (int i = threadIdx.x+WARP_SIZE; i < min(BLOCK_SIZE, bucket_size*WARP_SIZE); i += WARP_SIZE) {\n        for (int k = 0; k < x128::size; k++) {\n            accum[k] += accum_shared[i + k * BLOCK_SIZE];\n        }\n    }\n\n    // Add the result to dwte and write back to global memory (read-modify-write)\n    for (unsigned int k = 0; k < x128::size; k++) {\n        // We use stochastic rounding to go from FP32 to BF16\n        // The seed is deterministic and unique for each parameter to guarantee we have determinism AND\n        // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB\n        // and that somehow messing the quality of random numbers\n        stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + bucket * WARP_SIZE + threadIdx.x + k);\n    }\n    store128(dwte_ix, packed_in_out);\n}\n\n__global__ void wpe_backward_kernel(floatX* dwpe,\n                                    const floatX* dout, const int* inp,\n                                    int B, int T, int C, unsigned int seed) {\n    // Each thread handles x128::size \"channel positions\", e.g. 
256 per warp for BF16\n    // For gpt2-124M BF16, C=768 and T=1024, so 3 warps per channel and 3072 warps in total\n    // For each \"channel position\" we sum the gradients for every batch at that C/T element\n    // This way each dwpe element is only updated once, and the kernel is fully deterministic!\n    // The previous kernel was not deterministic, as batches were aggregated with atomicAdd\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n    if (idx >= T * C) { return; }\n\n    // if C is not a multiple of WARP_SIZE*x128::size, it's OK for some warps to handle multiple t\n    int t = idx / C;\n    int c = idx % C;\n    float accum[x128::size] = {0.0f};\n\n    for (int b = 0; b < B; b++) {\n        x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again\n        for (int k = 0; k < x128::size; k++) {\n            accum[k] += (float)packed_dout[k];\n        }\n    }\n\n    floatX* dwpe_tc = dwpe + (t * C) + c;\n    x128 packed_dwpe = load128(dwpe_tc);\n    for (unsigned int k = 0; k < x128::size; k++) {\n        // We use stochastic rounding to go from FP32 to BF16\n        // The seed is deterministic and unique for each parameter to guarantee we have determinism AND\n        // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB\n        // and that somehow messing the quality of random numbers\n        stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + idx + k);\n    }\n    store128(dwpe_tc, packed_dwpe);\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\nvoid encoder_forward(floatX* out,\n                     const int* inp, const floatX* wte, const floatX* wpe,\n                     int B, int T, int C, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 256;\n    const int N = B * T * C;\n    const int grid_size = CEIL_DIV(N, (int)(block_size * 
x128::size));\n    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\n// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details)\nvoid encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch\n                      int* workload_indices, int4* bucket_info,    // cpu scratch buffers\n                      const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs\n                      int B, int T, int C, unsigned int seed, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n\n    // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)\n    const int block_size = 256;\n    const int N = T * C / x128::size;\n    const int grid_size = CEIL_DIV(N, block_size);\n    wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);\n    cudaCheck(cudaGetLastError());\n\n    // check the GPU scratch buffer is large enough to hold the bucket info and workload indices\n    // todo - this is trivially true given hardcoded scratch buffer size here, is this useful?\n    int num_c_groups = CEIL_DIV(C, x128::size * WARP_SIZE);\n    assert(B*T*num_c_groups * (sizeof(int4)+sizeof(int)) <= B*T*3*C * sizeof(floatX));\n\n    // Step 1: Sort inputs into buckets\n    int total_items = 0;\n    std::unordered_map<uint64_t, std::vector<uint64_t>> buckets;\n    for (uint64_t bt = 0; bt < B * T; bt++) {\n        for (uint64_t c_group = 0; c_group < num_c_groups; c_group++) {\n            // todo - passing c_group/inputs_cpu[bt] in data to avoid a second hash lookup is a bit hacky\n            uint64_t data = bt + (c_group<<32ULL) + ((uint64_t)inputs_cpu[bt]<<42ULL);\n            buckets[c_group + num_c_groups * inputs_cpu[bt]].push_back(data);\n            total_items++;\n        }\n    }\n\n    // Step 2: Sort buckets by size in descending order\n    // 
this is so the largest buckets are processed first by the GPU\n    // otherwise, if they started late, they would still be running with the rest of the GPU idle\n    std::vector<std::pair<uint64_t, std::vector<uint64_t>>> sortedBuckets(buckets.begin(), buckets.end());\n    std::sort(sortedBuckets.begin(), sortedBuckets.end(), // ugly because we don't have a typedef for the std::pair\n              [](const std::pair<uint64_t, std::vector<uint64_t>>& a, const std::pair<uint64_t, std::vector<uint64_t>>& b) {\n                  return a.second.size() > b.second.size();\n              });\n\n    int num_buckets = buckets.size();\n    int bucket_index = 0;\n    int workload_index = 0;\n    for (const auto& bucket : sortedBuckets) {\n        bucket_info[bucket_index].x = workload_index; // bucket start\n        bucket_info[bucket_index].y = bucket.second.size(); // bucket size\n        bucket_info[bucket_index].z = (bucket.second[0] >> 42ULL) & ((1ULL<<20ULL)-1); // bucket ix\n        bucket_info[bucket_index].w = (bucket.second[0] >> 32ULL) & ((1ULL<<10ULL)-1); // bucket c\n\n        for (uint64_t idx : bucket.second) {\n            workload_indices[workload_index++] = (int)(idx & ((1ULL<<31ULL)-1ULL));\n        }\n        bucket_index++;\n    }\n\n    // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice)\n    // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely\n    int4* d_bucket_info = (int4*)scratch;\n    int*  d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4));\n    cudaCheck(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream));\n    cudaCheck(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream));\n\n    // Launch wte kernel\n    // todo - profile block sizes on more content (depends on number of buckets and on GPU?)\n    
wte_backward_kernel<256><<<num_buckets, 256, 0, stream>>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/fused_classifier.cuh",
    "content": "/*\nFused Classifier:\n- Forwards the Cross Entropy Loss\n- Never materializes the full normalized logits, only at the target label\n- (fusion) Also kicks off the backward pass, because everything is already loaded\n*/\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\nstruct SoftmaxParams {\n    float Scale;\n    float Offset;\n};\n\n__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) {\n    // same but not float4\n    // one row of inp, i.e. inp[idx, :] of shape (V,)\n\n    const floatX* x = inp + idx * P;\n    float thread_maxval = -INFINITY;\n    float thread_sumval = 0.0f;\n    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;\n\n    // special-case loop to handle the unaligned elements at the end of the array\n    // this lets us skip the bounds check in the main loop below, which improves performance\n    while ((i+1)*x128::size > V) {\n        for(int k = 0; k < x128::size; ++k) {\n            if (i*x128::size+k >= V) {\n                break; // bounds checking against real V (rather than padded P)\n            }\n            float v = (float)x[i*x128::size+k];\n            float old_maxval = thread_maxval;\n            thread_maxval = fmaxf(thread_maxval, v);\n            thread_sumval *= expf((old_maxval - thread_maxval));\n            thread_sumval += expf(v - thread_maxval);\n        }\n        i -= blockDim.x;\n    }\n\n    // main loop for the bulk of the iterations (no bounds checking required!)\n    for (; i >= 0; i -= blockDim.x) {\n        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop\n        for(int k = 0; k < x128::size; ++k) {\n            float v = (float)packed_x[k];\n            float old_maxval = thread_maxval;\n            thread_maxval = fmaxf(thread_maxval, v);\n            
thread_sumval *= expf((old_maxval - thread_maxval));\n            thread_sumval += expf(v - thread_maxval);\n        }\n    }\n\n    // Block Max Reduction -> Maths -> Block Sum Reduction\n    float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -INFINITY);\n    thread_sumval *= expf(thread_maxval - block_maxval);\n    float block_sumval = blockReduce<warpReduceSum>(thread_sumval);\n\n    // return the softmax parameters\n    return SoftmaxParams{1.f / block_sumval, block_maxval};\n}\n\n// will _update_ logits to logit gradients\n// uses template to decide whether to write logits and probs\n// split both loops in \"multiple-of-x128-size\" and \"bounds-checked remainder\" parts\ntemplate <bool WriteDLogits = true, bool WriteProbs = false>\n__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)\n    fused_classifier_kernel5(floatX* logits, float* losses, floatX* probs,\n                                const float dloss, const int* targets,\n                                int B, int T, int V, int P, std::bool_constant<WriteDLogits>) {\n    // note: idx is small enough that it easily fits into 32 bit;\n    // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P)\n    // are done in 64 bit\n    int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data\n    int ix = targets[idx];\n\n    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)\n    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);\n\n    // calculate the probability needed for the loss and update (single-threaded)\n    if(threadIdx.x == 0) {\n        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[idx] -= logf(prob);\n    }\n\n    // without this synchronization point we have a race condition:\n    // the logits used above to compute the loss are concurrently (race) modified to carry backward pass grads.\n    // since the 
\"logits\" are overwritten to be in the [-1, 1] range and sp.Offset is sometimes smaller than -90\n    // we erroneously end up computing exp^(90+) which gives us infinities in the loss! this is the fix.\n    __syncthreads();\n\n    // calculate the gradients directly, saves bandwidth from probs during training\n    // but also supports writing probs for inference-only and debugging\n    const floatX* logits_vec = logits + idx * P;\n    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {\n        // this is the 2nd read of logits after the one in prepare_softmax_blockwide3\n        // it will be overwritten by the logits gradients which is when we reduce cache persistence\n        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs\n        x128 packed_probs;\n        for(int k = 0; k < x128::size; ++k) {\n            int element = i*x128::size + k;\n            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;\n            packed_probs[k] = (floatX)prob;\n            float indicator = (element == ix) ? 1.0f : 0.0f;\n            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);\n        }\n        if (WriteDLogits){\n            // reduce cache persistence for the overwritten logits\n            // to maximise probability that logits remain in cache between prepare_softmax and here\n            store128cs(logits + idx * P + i * x128::size, packed_logits_vec);\n        }\n        if (WriteProbs) {\n            store128(probs + idx * P + i * x128::size, packed_probs);\n        }\n    }\n\n    // handle remaining elements after the last multiple of x128::size\n    // e.g. 
if V = 8003, and x128::size = 8, we need to handle the last 3 elements\n    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size\n    for (int i = threadIdx.x + unaligned_start; i < V; i++) {\n        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;\n        float indicator = (i == ix) ? 1.0f : 0.0f;\n        float dlogit = (prob - indicator) * dloss;\n        if (WriteDLogits){\n            __stcs(logits + idx * P + i, (floatX)dlogit);\n        }\n        if (WriteProbs) {\n            probs[idx * P + i] = (floatX)prob;\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\n// replaces logits with logit gradients\ntemplate <typename Type, bool WriteDLogits>\nvoid fused_classifier(Type* logits, float* losses,\n                      const float dloss, const int* targets,\n                      int B, int T, int V, int P, std::bool_constant<WriteDLogits> write_dlogits, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 1024;\n    const int N = B * T;\n    const int grid_size = N;\n    fused_classifier_kernel5<<<grid_size, block_size, 0, stream>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/gelu.cuh",
    "content": "/*\n(Approximate) GeLU non-linearity layer\n*/\n#include <assert.h>\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\n#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)\n__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n\n    x128 packed_out;\n    x128 packed_inp = load128cs(inp + idx); // load and do not keep in cache\n    for(int k = 0; k < packed_inp.size; ++k) {\n        float xi = (float)packed_inp[k];\n        float cube = 0.044715f * xi * xi * xi;\n        packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))));\n    }\n    // store instead of storecs (without cache streaming) in case it is useful for the\n    // data to be in the cache for the next operation after this GeLU\n    store128(out + idx, packed_out);\n}\n\n__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n\n    x128 packed_dinp;\n    x128 packed_inp = load128cs(inp + idx);\n    x128 packed_dout = load128(d_in_out + idx);\n    for (int k = 0; k < packed_inp.size; ++k) {\n        float x = (float)packed_inp[k];\n        float cube = 0.044715f * x * x * x;\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        float tanh_out = tanhf(tanh_arg);\n        float coshf_out = coshf(tanh_arg);\n        float sech_out = 1.0f / (coshf_out * coshf_out);\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);\n    }\n    store128(d_in_out + idx, packed_dinp);\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\nvoid 
gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 512;\n    assert(N % (block_size * x128::size) == 0);\n    const int grid_size = CEIL_DIV(N, block_size * x128::size);\n    gelu_forward_kernel2<<<grid_size, block_size, 0, stream>>>(out, inp);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 128;\n    assert(N % (block_size * x128::size) == 0);\n    const int grid_size = CEIL_DIV(N, block_size * x128::size);\n    gelu_backward_inplace_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, inp);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/global_norm.cuh",
    "content": "/*\nGlobal norm, used in gradient clipping\n*/\n#include <assert.h>\n#include <stddef.h>\n#include <cuda_runtime_api.h>\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\ntemplate<class T>\n__device__ float global_norm_squared_for_range(const T* data, size_t count) {\n    size_t index = blockIdx.x * blockDim.x + threadIdx.x;\n    size_t grid_width = blockDim.x * gridDim.x;\n    float accumulator = 0.f;\n    for(size_t i = index; i < count; i += grid_width) {\n        accumulator += (float)data[i] * (float)data[i];\n    }\n    // block-level reduce\n    return blockReduce<warpReduceSum>(accumulator);\n}\n\ntemplate<class T>\n__global__ void global_norm_squared_kernel(float* out, const T* data, size_t count, ptrdiff_t stride) {\n    float block_sum = global_norm_squared_for_range(data + blockIdx.y * stride, count);\n    // each block accumulates its partial sum to out[out_index]\n    // we want to avoid using atomic add here so we combine this kernel with another kernel call\n    // that sums up the partial block sums\n    if(threadIdx.x == 0) {\n        size_t out_index = blockIdx.y * gridDim.x + blockIdx.x;\n        out[out_index] = out[out_index] + block_sum;\n    }\n}\n\n__global__ void global_norm_aggregate_kernel(float* out, size_t grid_size) {\n    size_t index = threadIdx.x;\n    // grab block sums from the previous kernel, use 0. as the neutral sum element\n    float block_sum = (index < grid_size) ? 
out[index] : 0.f;\n    float sum = blockReduce<warpReduceSum>(block_sum);\n    if(threadIdx.x == 0) {\n        out[0] = sum;  // out[0] ends up with the final norm squared\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launcher\n\n// Helper function determines the maximum number of block sums\nint get_max_num_block_sums(int* num_slices_all, int numel) {\n    // NOTE: this needs to be kept in sync with `global_norm_squared` below.\n    const int block_size = 512;\n    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;\n    assert(grid_size > 0);\n    int max_num_block_sums = 0;\n    for (int i = 0; i < numel; i++) {\n        int num_slices = num_slices_all[i];\n        const int gx = CEIL_DIV(grid_size, num_slices);\n        const int gy = num_slices;\n        max_num_block_sums = max(max_num_block_sums, gx * gy);\n    }\n\n    return max_num_block_sums;\n}\n\ntemplate<typename T>\nvoid global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, int max_num_block_sums, bool reset, cudaStream_t stream) {\n    const int block_size = 512;\n    // launch just enough blocks to fill the grid. deliberately no CEIL_DIV.\n    // having one block less than possible is a tiny performance hit, having\n    // one block too many is catastrophic, since it can only start once all the other\n    // blocks finish. 
anyway, I think cuda_threads_per_SM should be a multiple of 512\n    // on all gpus, so the division really is going to be exact.\n    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;\n    assert(grid_size > 0);      // gives a better error than letting the call below fail\n\n    const int gx = CEIL_DIV(grid_size, num_slices);\n    const int gy = num_slices;\n\n    assert(gx * gy < 1024);  // we want to later accumulate the block sums in a single block\n\n    if (reset) {\n        cudaCheck(cudaMemsetAsync(out, 0, max_num_block_sums * sizeof(float), stream));\n    }\n    global_norm_squared_kernel<<<dim3(gx, gy), block_size, 0, stream>>>(out, values, count, stride);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/layernorm.cuh",
    "content": "/*\nLayerNorm CUDA kernel, and also Residual, because sometimes they are fused\n\nNote in llm.c we try to be clever in the backward pass to conserve memory.\nAll parameters use a += in the backward pass, so we can do gradient accumulation.\nBut all activations have = instead of += because these are faster (just read, no write).\nThis is okay for all activations except for those in the residual stream, where the\ngradients have to add. We make sure that we do a += as necessary.\nE.g., the layernorms are connected to the residuals so we += in layernorm backward.\n*/\n\n#include <assert.h>\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\n__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const floatX*  __restrict__ inp, const floatX*  __restrict__ weight,\n                                    const floatX* __restrict__ bias, int N, int C) {\n    int lane_id = threadIdx.x % WARP_SIZE;\n    int warp_id = threadIdx.x / WARP_SIZE;\n    int num_warps = blockDim.x / WARP_SIZE;\n\n    int idx = blockIdx.x * num_warps + warp_id;\n    if(idx >= N) { return; } // guard\n\n    // the row of input that this group of threads is responsible for\n    const floatX* x = inp + idx * C;\n\n    // mean\n    float sum = 0.0f;\n    for (int i = lane_id; i < C; i += WARP_SIZE) {\n        sum += (float)x[i];\n    }\n    sum = warpReduceSum(sum);\n    float m = sum / C;\n    if(lane_id == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n\n    // rstd\n    sum = 0.0f;\n    for (int i = lane_id; i < C; i += WARP_SIZE) {\n        float diff = (float)x[i] - m;\n        sum += diff * diff;\n    }\n    sum = warpReduceSum(sum);\n    float s = rsqrtf(sum / C + 1e-5f);\n    if(lane_id == 0 && rstd != nullptr) {\n        __stcs(rstd 
+ idx, s);\n    }\n\n    // final normalization and scaling by weight/bias\n    floatX* o = out + idx * C;\n    for (int c = lane_id; c < C; c += WARP_SIZE) {\n        // load and store using the .cs \"streaming\" hint to the compiler,\n        // indicating that this data will not be reused soon, and can be streamed through the caches\n        // this allows the threads to get more cache-hits for the (shared) weight and bias parameters\n        float n = s * ((float)__ldcs(x+c) - m);\n        __stcs(o+c, (floatX)(n * (float)weight[c] + (float)bias[c]));\n    }\n}\n\n__global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const floatX*  __restrict__ inp, const floatX*  __restrict__ weight,\n                                    const floatX* __restrict__ bias, int N, int C) {\n    assert(blockDim.x == WARP_SIZE);\n\n    // load weights and biases into shared memory\n    // do this before we allow any threads to exit!\n    extern __shared__ char* params[];\n    // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so\n    // let's keep everything as x128\n    x128* s_weight = reinterpret_cast<x128*>(params);\n    x128* s_bias = reinterpret_cast<x128*>(params) + (C / x128::size);\n    x128* s_in = reinterpret_cast<x128*>(params) + ((2 + threadIdx.y) * C / x128::size);\n\n    int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;\n    for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {\n        s_weight[i/x128::size] = load128(weight + i);\n        s_bias[i/x128::size] = load128(bias + i);\n    }\n    __syncthreads();\n\n    int idx = blockIdx.x * blockDim.y + threadIdx.y;\n    if(idx >= N) { return; } // guard\n\n    // adjust pointers to current token\n    inp += idx * C;\n    out += idx * C;\n\n    const float eps = 1e-5f;\n    float sum = 0.0f;\n    for(int c = threadIdx.x * x128::size; c < C; c += 
WARP_SIZE * x128::size) {\n        const x128 in_data = load128cs(inp + c);\n        for(int k = 0; k < x128::size; ++k) {\n            sum += (float)in_data[k];\n        }\n        s_in[c / x128::size] = in_data;\n    }\n\n    sum = warpReduceSum(sum);\n    float m = sum / C;\n    float v = 0.f;\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 in_data = s_in[c / x128::size];\n        for(int k = 0; k < x128::size; ++k) {\n            v += ((float)in_data[k] - m) * ((float)in_data[k] - m);\n        }\n    }\n\n    v = warpReduceSum(v) / C;\n    float s = rsqrtf(v + eps);\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 in_data = s_in[c / x128::size];\n        const x128 w = s_weight[c / x128::size];\n        const x128 b = s_bias[c / x128::size];\n        x128 out_data;\n        for(int k = 0; k < x128::size; ++k) {\n            float n = s * ((float)in_data[k] - m); // normalized output\n            float o = n * (float)w[k] + (float)b[k]; // scale and shift it\n            out_data[k] = (floatX)o;\n        }\n\n        store128cs(out + c, out_data);\n    }\n    // cache the mean and rstd for the backward pass later\n    if(threadIdx.x == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n    // store the rstd, no need to cache it\n    if(threadIdx.x == 0 && rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n}\n\n__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd,\n                                               const floatX* inp1, const floatX* inp2,\n                                               const floatX* weight, const floatX* bias,\n                                               int N, int C) {\n    assert(blockDim.x == WARP_SIZE);\n\n    // load weights and biases into shared memory\n    // do this before we allow any threads to exit!\n    extern __shared__ char* params[];\n  
  // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so\n    // let's keep everything as x128\n    x128* s_weight = reinterpret_cast<x128*>(params);\n    x128* s_bias = reinterpret_cast<x128*>(params) + (C / x128::size);\n    x128* s_res = reinterpret_cast<x128*>(params) + ((2 + threadIdx.y) * C / x128::size);\n\n    int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size;\n    for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) {\n        s_weight[i/x128::size] = load128(weight + i);\n        s_bias[i/x128::size] = load128(bias + i);\n    }\n    __syncthreads();\n\n    int idx = blockIdx.x * blockDim.y + threadIdx.y;\n    if(idx > N) return;\n\n    // adjust pointers to current token\n    residual += C * idx;\n    normed += C * idx;\n    inp1 += C * idx;\n    inp2 += C * idx;\n\n    const float eps = 1e-5f;\n    float sum = 0.0f;\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 in1 = load128cs(inp1 + c);\n        const x128 in2 = load128cs(inp2 + c);\n        x128 out;\n        for(int k = 0; k < x128::size; ++k) {\n            out[k] = (float)in1[k] + (float)in2[k];\n            sum += (float)out[k];\n        }\n        store128cs(residual + c, out);\n        s_res[c / x128::size] = out;\n    }\n\n    sum = warpReduceSum(sum);\n    float m = sum / C;\n    float v = 0.f;\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 res = s_res[c / x128::size];\n        for(int k = 0; k < x128::size; ++k) {\n            v += ((float)res[k] - m) * ((float)res[k] - m);\n        }\n    }\n\n    v = warpReduceSum(v) / C;\n    float s = rsqrtf(v + eps);\n\n    for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) {\n        const x128 res = s_res[c / x128::size];\n        const x128 w = s_weight[c / x128::size];\n        const x128 b = s_bias[c / x128::size];\n        x128 out;\n        
for(int k = 0; k < x128::size; ++k) {\n            float n = s * ((float)res[k] - m); // normalized output\n            float o = n * (float)w[k] + (float)b[k]; // scale and shift it\n            out[k] = o;\n        }\n\n        store128cs(normed + c, out);\n    }\n    // cache the mean and rstd for the backward pass later\n    if(threadIdx.x == 0) {\n        mean[idx] = m;\n        rstd[idx] = s;\n    }\n}\n\n__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) {\n    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;\n\n    x128 packed_out;\n    x128 packed_inp1 = load128cs(inp1 + idx);\n    x128 packed_inp2 = load128cs(inp2 + idx);\n    for (int k = 0; k < packed_inp1.size; k++) {\n        packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]);\n    }\n    store128(out + idx, packed_out);\n}\n\n__global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads?\n    layernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                                const floatX* dout, const floatX* inp, const floatX* weight,\n                                const float* mean, const float* rstd,\n                                int B, int T, int C) {\n    int BLOCK_SIZE = blockDim.x;\n    int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block\n    extern __shared__ float shared[];\n\n    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block\n    int baseIdx = blockIdx.x * warpsInBlock + warpId;\n    int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp\n    int warpsInGrid = gridDim.x * warpsInBlock;\n    int C_per_iteration = WARP_SIZE * x128::size;\n    int iterations_C = CEIL_DIV(C, C_per_iteration); // + 2;\n\n    // the first half of shared memory is bias, second is weight\n    size_t rounded_C = CEIL_DIV(C, (32 * x128::size)) * (32 * x128::size);\n    float* dbias_shared = shared;\n    
float* dweight_shared = shared + rounded_C;\n    // warp zero doesn't actually write to the _tmp_shared memory locations, so we don't need to reserve memory\n    // the obvious solution is to change the addressing below to use (threadIdx.x-32) as offset, but that causes\n    // register spills, so instead we mess with the base pointer here, which doesn't increase register usage.\n    float* dbias_tmp_shared = shared + 2 * rounded_C - WARP_SIZE * f128::size;\n    float* dweight_tmp_shared = shared + 2 * rounded_C + f128::size * BLOCK_SIZE - 2 * WARP_SIZE * f128::size;\n\n    // init shared memory to zero\n    for(int i = threadIdx.x * f128::size; i < rounded_C; i += BLOCK_SIZE * f128::size) {\n        store128(dbias_shared + i, f128::zeros());\n        store128(dweight_shared + i, f128::zeros());\n    }\n    __syncthreads();\n\n    for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) {\n        const floatX* dout_bt = dout + bt * C;\n        const floatX* inp_bt = inp +bt * C;\n        floatX* dinp_bt = dinp + bt * C;\n\n        // first: two reduce operations\n        float dnorm_mean = 0.0f;\n        float dnorm_norm_mean = 0.0f;\n        for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) {\n            x128 dout128_i   = load128(dout_bt + i);\n            x128 inp128_i    = load128(inp_bt  + i);\n            x128 weight128_i = load128(weight  + i);\n            for (int k = 0; k < x128::size; k++) {\n                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * (float)inp128_i[k];\n            }\n        }\n\n        const float mean_bt = mean[bt];\n        const float rstd_bt = rstd[bt];\n        dnorm_mean = warpReduceSum(dnorm_mean) / C;\n        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C * rstd_bt - dnorm_mean * mean_bt * rstd_bt;\n\n        for (int c = 0; c < iterations_C; c++) {\n            int global_index = 
(warpThreadIdx * x128::size) + (c * C_per_iteration);\n\n            x128 dout128   = x128::zeros();\n            x128 inp128    = x128::zeros();\n            x128 dinp128   = x128::zeros();\n            x128 weight128 = x128::zeros();\n\n            if(global_index < C) {\n                dout128 = load128cs(dout_bt + global_index);\n                inp128 = load128cs(inp_bt + global_index);\n                dinp128 = load128(dinp_bt + global_index);\n                weight128 = load128(weight + global_index);\n            }\n\n            for(int o = 0; o < x128::size / f128::size; ++o) {\n                f128 dbias_f;\n                f128 dweight_f;\n                for(int i = 0; i < f128::size; ++i) {\n                    int x = o * f128::size + i;\n                    float dout_i = (float)dout128[x];\n                    float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;\n                    dbias_f[i] = dout_i;\n                    dweight_f[i] = norm_bti * dout_i;\n\n                    float dval = 0.0f;\n                    dval += (float) weight128[x] * (float)dout128[x]; // term 1\n                    dval -= dnorm_mean; // term 2\n                    dval -= norm_bti * dnorm_norm_mean; // term 3\n                    dval *= rstd_bt; // final scale\n                    dinp128[x] = (floatX) ((float) dinp128[x] + dval);\n                }\n\n                if (warpId != 0) {\n                    store128(dbias_tmp_shared + threadIdx.x * f128::size, dbias_f);\n                    // this seems to generate a 64-bit store, instead of 128-bit.\n                    // however, forcing 128-bit (e.g., using inline ptx), results in register\n                    // spilling and much worse performance, so we'll keep it like this for now\n                    // but ideally, we could reduce the register pressure a little.\n                    store128(dweight_tmp_shared + threadIdx.x * f128::size, dweight_f);\n                }\n                
__syncthreads();\n                if (warpId == 0) {\n                    for (int j = 1; j < warpsInBlock; j++) {\n                        f128 dbias_tmp = load128(dbias_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE));\n                        f128 dweight_tmp = load128(dweight_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE));\n                        for(int i = 0; i < f128::size; ++i) {\n                            dbias_f[i] += dbias_tmp[i];\n                            dweight_f[i] += dweight_tmp[i];\n                        }\n                    }\n                }\n                __syncthreads();\n                if (warpId == 0) {\n                    f128 db_old = load128(dbias_shared + global_index + f128::size * o);\n                    f128 dw_old = load128(dweight_shared + global_index + f128::size * o);\n                    for(int i = 0; i < f128::size; ++i) {\n                        dbias_f[i] += db_old[i];\n                        dweight_f[i] += dw_old[i];\n                    }\n                    store128(dbias_shared + global_index + f128::size * o, dbias_f);\n                    store128(dweight_shared + global_index + f128::size * o, dweight_f);\n                }\n            }\n            if(global_index < C) {\n                // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing\n                store128cg(dinp_bt + global_index, dinp128);\n            }\n        }\n    }\n    __syncthreads();\n    // Each block writes its partial sum to global memory\n    // The last block to finish becomes responsible for summing up all the partial sums\n    // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)\n    unsigned int* scratchFlag = (unsigned int*)(scratch);\n    // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned\n    scratch += 32;\n    float* scratch_dbias = scratch;\n    float* scratch_dweight = scratch 
+ C;\n    for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) {\n        // Write to global memory in the same \"shared memory banking friendly\" order\n        store128(scratch_dbias + i + 2*C*blockIdx.x, load128(dbias_shared + i));\n        store128(scratch_dweight + i + 2*C*blockIdx.x, load128(dweight_shared + i));\n    }\n    __syncthreads();\n    // that portion of shared memory is no longer used, so we can repurpose it for the scratch flag.\n    unsigned int *tmp_flag = (unsigned int*)(shared + 2*rounded_C);\n    if (threadIdx.x == 0) {\n        *tmp_flag = atomicInc(scratchFlag, gridDim.x);\n    }\n    __syncthreads();\n    if (*tmp_flag == gridDim.x-1) {\n        // Reduction of the partial sums by the final block\n        // todo - there isn't enough parallelism even inside that single SM...\n        // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?!\n        for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) {\n            f128 dbias_accum = f128::zeros();\n            f128 dweight_accum = f128::zeros();\n\n            for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {\n                int offset = i + 2*C*read_block_idx;\n                f128 dbias128 = load128(scratch_dbias + offset);\n                f128 dweight128 = load128(scratch_dweight + offset);\n                for(int k = 0; k < f128::size; k++) {\n                    dbias_accum[k] += dbias128[k];\n                    dweight_accum[k] += dweight128[k];\n                }\n            }\n            store128(dbias_shared + i, dbias_accum);\n            store128(dweight_shared + i, dweight_accum);\n        }\n        __syncthreads();\n\n        // convert from float/FP32 to floatX/BF16 for the final write\n        // this is separate because it cannot use as many warps as the above (f128 vs x128)\n        // todo - if we split this code into another kernel, we could maybe do it at the 
same time?\n        for (int c = warpId; c < iterations_C; c += warpsInBlock) {\n            int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration);\n            if (global_index >= C) {\n                break;\n            }\n\n            x128 dbias128 = load128(dbias + global_index);\n            x128 dweight128 = load128(dweight + global_index);\n            for(int o = 0; o < x128::size / f128::size; ++o) {\n                f128 s_db = load128(dbias_shared + global_index + o * f128::size);\n                f128 s_dw = load128(dweight_shared + global_index + o * f128::size);\n                for(int i = 0; i < f128::size; ++i) {\n                    int x = o * f128::size + i;\n                    dbias128[x] = (floatX)(s_db[i] + (float)dbias128[x]);\n                    dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]);\n                }\n            }\n            store128(dbias + global_index, dbias128);\n            store128(dweight + global_index, dweight128);\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\n// similar to `fused_residual_forward5`\nvoid layernorm_forward(floatX* out, float* mean, float* rstd,\n                       floatX* inp, const floatX* weight, const floatX* bias,\n                       int B, int T, int C, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 256;\n    int block_y = block_size / WARP_SIZE;\n    const int N = B * T;\n    const int grid_size = CEIL_DIV(N, block_y);\n    size_t smem = (2 + block_y) * C * sizeof(floatX);\n\n    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute\n    // this may fail, in which case we fall back to the smem free implementation.\n    cudaCheck(cudaGetLastError());\n    auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);\n    cudaCheck(cudaGetLastError());\n    if (status == 
cudaSuccess) {\n        layernorm_forward_kernel6<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(out, mean, rstd, inp, weight, bias, N, C);\n    } else {\n        // fall back to the version without shared memory\n        const int grid_size_fb = CEIL_DIV(N * WARP_SIZE, block_size);\n        layernorm_forward_kernel3<<<grid_size_fb, block_size, 0, stream>>>(out, mean, rstd, inp, weight, bias, N, C);\n    }\n    cudaCheck(cudaGetLastError());\n}\n\nvoid residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 256;\n    assert(N % (block_size * x128::size) == 0);\n    const int grid_size = CEIL_DIV(N, block_size * x128::size);\n    residual_forward_kernel<<<grid_size, block_size, 0, stream>>>(out, inp1, inp2);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid fused_residual_forward5(floatX* residual, floatX* normed, float* mean, float* rstd,\n                             const floatX* inp1, const floatX* inp2,\n                             const floatX* weight, const floatX* bias,\n                             int N, int C, cudaStream_t stream) {\n    const int block_size = 256;\n    int block_y = block_size / WARP_SIZE;\n    const int grid_size = CEIL_DIV(N, block_y);\n    size_t smem = (2 + block_y) * C * sizeof(floatX);\n\n    // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute\n    // this may fail, in which case we fall back to the smem free implementation.\n    cudaCheck(cudaGetLastError());\n    auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);\n    cudaCheck(cudaGetLastError());\n    if(status == cudaSuccess) {\n        fused_residual_forward_kernel5<<<grid_size, dim3(WARP_SIZE, block_y), smem, stream>>>(residual, normed,\n                                                                                              mean, rstd, inp1, inp2,\n                              
                                                                weight, bias, N, C);\n    } else {\n        residual_forward(residual, inp1, inp2, N*C, stream);\n        layernorm_forward(normed, mean, rstd, residual, weight, bias, N, 1, C, stream);\n    }\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,\n                        const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd,\n                        int B, int T, int C, cudaStream_t stream) {\n    NVTX_RANGE_FN();\n    const int block_size = 512;\n    const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3\n    const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount;\n    size_t rounded_C = CEIL_DIV(C, (32 * x128::size)) * (32 * x128::size);\n    size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float);\n\n    cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0\n    layernorm_backward_kernel10<<<grid_size, block_size, shared_mem_size, stream>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n"
  },
  {
    "path": "llmc/logger.h",
    "content": "/*\nImplements a simple logger that writes log files in the output directory.\nThe Logger object is stateless and uses append mode to write to log files.\n*/\n#ifndef LOGGER_H\n#define LOGGER_H\n\n#include <assert.h>\n#include <stdio.h>\n#include <string.h>\n// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck\n#include \"utils.h\"\n\ntypedef struct {\n    int active;\n    char output_log_file[512];\n} Logger;\n\nvoid logger_init(Logger *logger, const char *log_dir, int process_rank, int resume) {\n    // currently, only rank 0 writes logs\n    logger->active = 0;\n    if (log_dir != NULL && process_rank == 0) {\n        logger->active = 1;\n        assert(strlen(log_dir) < 500); // being a bit lazy, could relax later\n        snprintf(logger->output_log_file, 512, \"%s/main.log\", log_dir);\n        if (resume == 0) {\n            // wipe any existing logfile clean if we're starting fresh\n            FILE *logfile = fopenCheck(logger->output_log_file, \"w\");\n            fclose(logfile);\n        }\n    }\n}\n\nvoid logger_log_eval(Logger *logger, int step, float val) {\n    if (logger->active == 1) {\n        FILE *logfile = fopenCheck(logger->output_log_file, \"a\");\n        fprintf(logfile, \"s:%d eval:%.4f\\n\", step, val);\n        fclose(logfile);\n    }\n}\n\nvoid logger_log_val(Logger *logger, int step, float val_loss) {\n    if (logger->active == 1) {\n        FILE *logfile = fopenCheck(logger->output_log_file, \"a\");\n        fprintf(logfile, \"s:%d tel:%.4f\\n\", step, val_loss);\n        fclose(logfile);\n    }\n}\n\nvoid logger_log_train(Logger *logger, int step, float train_loss, float learning_rate, float grad_norm) {\n    if (logger->active == 1) {\n        FILE *logfile = fopenCheck(logger->output_log_file, \"a\");\n        fprintf(logfile, \"s:%d trl:%.4f lr:%.6f norm:%.2f\\n\", step, train_loss, learning_rate, grad_norm);\n        fclose(logfile);\n    }\n}\n\n#endif"
  },
  {
    "path": "llmc/matmul.cuh",
    "content": "/*\nMatrix Multiplication, with help from cuBLASLt\n*/\n#include <assert.h>\n#include <type_traits>      // std::bool_constant\n// llmc internal imports\n#include \"cuda_common.h\"\n#include \"cuda_utils.cuh\"\n#include \"cublas_common.h\"\n// GELU can be either fused (cublasLt) or non-fused (gelu.h)\n#include \"gelu.cuh\"\n\n// ----------------------------------------------------------------------------\n// CUDA kernels\n\ntemplate<typename OutFloat, bool UseAuxBuffer>\n__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int B, int T, int OC,\n                                             std::bool_constant<UseAuxBuffer>) {\n    constexpr const int bdx = 4;\n    constexpr const int bdy = WARP_SIZE / bdx;\n    assert(blockDim.x == bdx);\n    assert(blockDim.y == bdy);\n\n    int warp_d = (int)threadIdx.x;\n    int warp_c = (int)threadIdx.y;\n    int block_d = (int)threadIdx.z;\n\n    const int OC_per_warp = bdy * x128::size;  // 64 at BF16\n\n    int local_oc = warp_c * x128::size;\n    int global_oc = blockIdx.x * OC_per_warp + local_oc;\n\n    int local_bt = warp_d + bdx * block_d;\n    int bt_per_block = bdx * blockDim.z;\n\n    float accumulators[x128::size];\n    for (int k = 0; k < x128::size; k++) {\n        accumulators[k] = 0.0f;\n    }\n\n    if(global_oc < OC) {\n        // sum up over all bt within registers\n        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {\n            x128 packed_dout = load128(dout + global_oc + idx*OC);\n            for (int k = 0; k < x128::size; k++) {\n                accumulators[k] += (float)packed_dout[k];\n            }\n        }\n    }\n\n    __shared__ float sub_results[x128::size][WARP_SIZE][bdy];\n\n    // reduce within-warp results\n    for (int k = 0; k < x128::size; k++) {\n        float v = accumulators[k];\n        v += __shfl_down_sync(0xffffffff, v, 1, 4);\n        v += __shfl_down_sync(0xffffffff, v, 2, 4);\n  
      if(warp_d == 0) {\n            sub_results[k][block_d][warp_c] = v;\n        }\n    }\n    __syncthreads();\n\n    // block-wide reductions\n    for (int k = block_d; k < x128::size; k += blockDim.z) {\n        float a = 0.f;\n        for (int r = warp_d; r < blockDim.z; r += bdx) {\n            float v = sub_results[k][r][warp_c];\n            v += __shfl_down_sync(0xffffffff, v, 1, 4);\n            v += __shfl_down_sync(0xffffffff, v, 2, 4);\n            a += v;\n        }\n        if(warp_d == 0 && global_oc < OC) {\n            if constexpr (!UseAuxBuffer) {\n                dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);\n            } else {\n                dbias[global_oc + k + blockIdx.y * OC] = a;\n            }\n        }\n    }\n}\n\n__global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, size_t m) {\n    const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size;\n    assert(n % x128::size == 0);\n    if (idx < n) {\n        f128 acc;\n        for(int k = 0; k < f128::size; ++k) {\n            acc[k] = 0.f;\n        }\n\n        for(int l = 0; l < m; ++l) {\n            f128 s = load128(src + idx + n * l);\n            for(int k = 0; k < f128::size; ++k) {\n                acc[k] += s[k];\n            }\n        }\n        for(int k = 0; k < f128::size; ++k) {\n            dst[idx + k] = (floatX) ((float)dst[idx + k] + acc[k]);\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\n// Wrapper around cublasLtMatmul that is meant to support everything we need in llm.c\n// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul\nvoid matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* bias,\n                     int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false,\n                     int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0,\n                 
    bool accumulate=false, floatX* pre_gelu=NULL, bool backward=false)\n{\n    NVTX_RANGE_FN();\n    bool has_bias = (bias != NULL);\n    bool has_gelu = (pre_gelu != NULL);\n\n    // check alignment (some modes work unaligned but it is always best to be aligned for performance)\n    if(((uintptr_t)a % 16) != 0 || ((uintptr_t)b % 16) != 0 || ((uintptr_t)d % 16) != 0 || ((uintptr_t)bias % 16) != 0) {\n        printf(\"All cuBLASLt pointers must be aligned!\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // create the operation descriptor\n    cublasLtMatmulDesc_t operationDesc;\n    cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute, CUDA_R_32F));\n\n    int returnedResults = 0;\n    cublasLtMatmulPreference_t preference;\n    cublasLtMatmulHeuristicResult_t heuristic;\n\n    cublasOperation_t opNoTranspose = CUBLAS_OP_N;\n    cublasOperation_t opTranspose = CUBLAS_OP_T;\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, (transA)  ? &opTranspose : &opNoTranspose,   sizeof(opTranspose)));\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, (transB) ? &opTranspose   : &opNoTranspose, sizeof(opNoTranspose)));\n\n    // define matrix layouts\n    cublasLtMatrixLayout_t ALayout;\n    cublasLtMatrixLayout_t BLayout;\n    cublasLtMatrixLayout_t DLayout;\n    cublasLtMatrixLayout_t CLayout;\n    if (transA) {\n        cublasCheck(cublasLtMatrixLayoutCreate(&ALayout, CUBLAS_LOWP, k, m, k));\n    } else {\n        cublasCheck(cublasLtMatrixLayoutCreate(&ALayout, CUBLAS_LOWP, m, k, m));\n    }\n    if (transB) {\n        cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, CUBLAS_LOWP, n, k, n));\n    } else {\n        cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, CUBLAS_LOWP, k, n, k));\n    }\n    // cuBLASLt requires C in FP8 mode to be BF16 or FP32... (sigh)\n    cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, (sizeof(floatX) == 1) ? 
CUDA_R_16BF : CUBLAS_LOWP, m, n, m));\n    cublasCheck(cublasLtMatrixLayoutCreate(&DLayout, CUBLAS_LOWP, m, n, m));\n\n    // Strided Batched GEMM (used for non-flash attention, equivalent to cublasGemmStridedBatchedEx)\n    if (batch_count) {\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(ALayout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(BLayout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(CLayout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(DLayout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));\n\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(ALayout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(BLayout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(CLayout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideOut, sizeof(strideOut)));\n        cublasCheck(cublasLtMatrixLayoutSetAttribute(DLayout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideOut, sizeof(strideOut)));\n    }\n\n    // create a preference handle with specified max workspace\n    cublasCheck(cublasLtMatmulPreferenceCreate(&preference));\n    cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,\n                                                     &cublaslt_workspace_size, sizeof(cublaslt_workspace_size)));\n\n    // setup epilogue and associated pointers for bias & gelu\n    cublasLtEpilogue_t epilogue;\n    if (has_gelu) {\n        int64_t gelu_ld = m; // todo - is this affected by anything else?\n        cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, 
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &gelu_ld, sizeof(gelu_ld)));\n        cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu, sizeof(pre_gelu)));\n        if (backward) {\n            assert(!has_bias); // we shouldn't have any backward matmuls that use both GELU and bias\n            epilogue = CUBLASLT_EPILOGUE_DGELU;\n        } else {\n            epilogue = has_bias ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX;\n        }\n    } else if(has_bias){\n        epilogue = backward ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS;\n    } else {\n        epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n    }\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));\n\n    if (has_bias) {\n        // cuBLASLt requires bias in FP8 mode to be BF16... (sigh)\n        cublasDataType_t bias_data_type = (sizeof(floatX) == 1) ? CUDA_R_16BF : CUBLAS_LOWP; // force BF16 bias for FP8 mode\n        cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_data_type, sizeof(bias_data_type)));\n        cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));\n    }\n\n    // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!)\n    cublasDataType_t scale_type = CUDA_R_32F;\n    cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));\n\n    // find a suitable algorithm (cached internally so shouldn't take much CPU time in practice)\n    cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, ALayout, BLayout, CLayout, DLayout,\n                                   preference, 1, &heuristic, &returnedResults);\n    if (returnedResults == 0) {\n        printf(\"No cuBLASLt algorithm: m: %d, n: %d, k: %d, bias: 
%d\\n\", n, m, k, has_bias);\n        exit(EXIT_FAILURE);\n    }\n\n    // set whether to accumulate (i.e. D += C) or not - note this isn't considered in algorithm selection (?!)\n    const float alpha = 1.0f, beta = accumulate ? 1.0f : 0.0f;\n\n    // call the matmul\n    cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc,\n                               &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout,\n                               &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream));\n\n    // cleanups\n    cublasCheck(cublasLtMatmulPreferenceDestroy(preference));\n    cublasCheck(cublasLtMatmulDescDestroy(operationDesc));\n    cublasCheck(cublasLtMatrixLayoutDestroy(ALayout));\n    cublasCheck(cublasLtMatrixLayoutDestroy(BLayout));\n    cublasCheck(cublasLtMatrixLayoutDestroy(CLayout));\n    cublasCheck(cublasLtMatrixLayoutDestroy(DLayout));\n    cudaCheck(cudaGetLastError());\n}\n\n// small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments)\nvoid matmul_forward_cublaslt(floatX* out,\n                     floatX* inp, floatX* weight, floatX* bias,\n                     int B, int T, int C, int OC, cudaStream_t stream,\n                     floatX* pre_gelu=NULL, int gelu_fusion=1) {\n    // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?)\n    if (gelu_fusion < 1 && pre_gelu) {\n        matmul_cublaslt(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false);\n        gelu_forward(out, pre_gelu, B*T*OC, stream);\n    } else {\n        matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false);\n    }\n}\n\nvoid matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,\n                     floatX* dout, floatX* inp, floatX* weight,\n                     float* dbias_buffer,\n                     int B, int T, int C, int OC, cudaStream_t 
stream,\n                     floatX* pre_gelu=NULL, int gelu_fusion=1) {\n    NVTX_RANGE_FN();\n\n    // backward to bias, if given, does a +=\n    if (dbias != NULL) {\n        // Each warp is responsible for 8 * \"x128::size\" = 64 OCs at BF16 (OC must be a multiple of 64!)\n        // Block size is 1024 | 768 threads (32|24 warps) and we reduce those values into 1 at the end\n\n        const int block_size = deviceProp.maxThreadsPerMultiProcessor == 1536 ? 768 : 1024;\n\n        dim3 block_dim = {4, 8, (unsigned)block_size/WARP_SIZE};\n        const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16\n        const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16\n        const int grid_size_y = max(1, deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!\n\n        // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation\n        // and write results directly to the output.\n        if(grid_size_y == 1) {\n            matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim, 0, stream>>>(dbias, dout, B, T, OC, False);\n            cudaCheck(cudaGetLastError());\n        } else {\n            // kernel 9 overwrites temp buffer, so no need to memset\n            matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim, 0, stream>>>(dbias_buffer, dout, B, T, OC, True);\n            cudaCheck(cudaGetLastError());\n            reduce_add_sum_kernel<<<CEIL_DIV(OC, 256 * f128::size), 256, 0, stream>>>(dbias, dbias_buffer, OC, grid_size_y);\n            cudaCheck(cudaGetLastError());\n        }\n        dbias = NULL; // prevent dbias calculation from also being fused in matmul_cublaslt below (if we enabled fusion)\n    }\n\n    // backward to input, uses = in the backward pass (set the gradient)\n    matmul_cublaslt(dinp, weight, dout, NULL, C, B*T, OC, stream, false, false, 
0, 0, 0, 0, false,\n                    gelu_fusion >= 2 ? pre_gelu : NULL, true);\n\n    // backward GELU (if it wasn't fused into the matmul above)\n    if (gelu_fusion < 2 && pre_gelu) {\n        gelu_backward_inplace(dinp, pre_gelu, B*T*C, stream);\n    }\n\n    // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one\n    matmul_cublaslt(dweight, inp, dout, NULL /*dbias*/, C, OC, B*T, stream, false, true, 0, 0, 0, 0,\n                    true /* accumulate */, NULL, true);\n}\n"
  },
  {
    "path": "llmc/mfu.h",
    "content": "#ifndef MFU_H\n#define MFU_H\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n#if __has_include(<nvml.h>)\n#define USE_NVML 1\n#include <nvml.h>\n#else\n#define USE_NVML 0\n#endif\n\n// tied to enum PrecisionMode, in a future refactor make them the same\n#define MFUH_PRECISION_FP32 0\n#define MFUH_PRECISION_FP16 1\n#define MFUH_PRECISION_BF16 2\n\n#if USE_NVML\ninline void nvml_check(nvmlReturn_t status, const char *file, int line) {\n    if (status != NVML_SUCCESS) {\n        printf(\"[NVML ERROR] at file %s:%d:\\n%s\\n\", file, line, nvmlErrorString(status));\n        exit(EXIT_FAILURE);\n    }\n};\n#define nvmlCheck(err) (nvml_check(err, __FILE__, __LINE__))\n#endif\n\n\ntypedef struct {\n    float TF_32;       // tensor-core performance 32 bit\n    float BF_16_32;    // bf16 with 32 bit accumulate\n    float FP_16_32;    // fp16 with 32 bit accumulate\n    float FP_16_16;    // fp16 with 16 bit accumulate\n    float FP_8_32;     // and so on\n    float FP_8_16;\n    float CLOCK;        // clock frequency from the spec sheet\n    float CORES;        // #TCs from the spec sheet\n} PerfData;\n\n// basic default data from the nvidia whitepapers\nstatic const PerfData VOLTA = {125.0f, -1.f, 125.f, -1.f, -1.f, -1.f, 1530.f, 640.f};\nstatic const PerfData AMPERE_DATACENTER = {156.f, 312.f, 312.f, 312.f, -1.f, -1.f, 1410.f, 432.f};\nstatic const PerfData AMPERE_CONSUMER = {40.f, 80.f, 80.f, 160.f, -1.f, -1.f, 1860.f, 336.f};\nstatic const PerfData HOPPER = {378.f, 756.f, 756.f, 756.f, 1513.f, 1513.f, 1620.f, 456.f};\nstatic const PerfData ADA = {82.6f, 165.2f, 165.2f, 330.3f, 330.3f, 660.6f, 2520.f, 512.f};\n\ntypedef struct {\n    const char* name;\n    const PerfData* perf_data;\n    float new_cores;\n    float new_mhz;\n} GPUEntry;\n\n// the overrides for each specific GPU\nstatic GPUEntry gpu_db[] = {\n    {\"Tesla V100-SXM2-16GB\", &VOLTA, 640, 1530},\n    {\"Tesla V100-PCIE-32GB\", &VOLTA, 640, 1530},\n    {\"NVIDIA 
A100-PCIE-40GB\", &AMPERE_DATACENTER, 432, 1410},\n    {\"NVIDIA A100-PCIE-80GB\", &AMPERE_DATACENTER, 432, 1410},\n    {\"NVIDIA A100-SXM4-40GB\", &AMPERE_DATACENTER, 432, 1410},\n    {\"NVIDIA A100-SXM4-80GB\", &AMPERE_DATACENTER, 432, 1410},\n    {\"NVIDIA RTX A2000\", &AMPERE_CONSUMER, 104, 1200},\n    {\"NVIDIA RTX A4000\", &AMPERE_CONSUMER, 192, 1560},\n    {\"NVIDIA RTX A4500\", &AMPERE_CONSUMER, 224, 1650},\n    {\"NVIDIA RTX A5000\", &AMPERE_CONSUMER, 256, 1695},\n    {\"NVIDIA RTX A5500\", &AMPERE_CONSUMER, 320, 1770},\n    {\"NVIDIA RTX A6000\", &AMPERE_CONSUMER, 336, 1800},\n    {\"NVIDIA GeForce RTX 3090 Ti\", &AMPERE_CONSUMER, 336, 1860},\n    {\"NVIDIA GeForce RTX 3090\", &AMPERE_CONSUMER, 328, 1695},\n    {\"NVIDIA GeForce RTX 3080 Ti\", &AMPERE_CONSUMER, 320, 1665},\n    {\"NVIDIA GeForce RTX 3080\", &AMPERE_CONSUMER, 272, 1710},\n    {\"NVIDIA GeForce RTX 3070 Ti\", &AMPERE_CONSUMER, 192, 1770},\n    {\"NVIDIA GeForce RTX 3070\", &AMPERE_CONSUMER, 184, 1725},\n    {\"NVIDIA GeForce RTX 3060 Ti\", &AMPERE_CONSUMER, 152, 1665},\n    {\"NVIDIA GeForce RTX 3060\", &AMPERE_CONSUMER, 112, 1777},\n    {\"NVIDIA RTX A2000 ADA\", &ADA, 88, 2130},\n    {\"NVIDIA RTX A4000 ADA\", &ADA, 192, 2175},\n    {\"NVIDIA RTX A4500 ADA\", &ADA, 224, 2580},\n    {\"NVIDIA RTX A5000 ADA\", &ADA, 400, 2550},\n    {\"NVIDIA RTX A5880 ADA\", &ADA, 440, 2460},\n    {\"NVIDIA RTX A6000 ADA\", &ADA, 568, 2505},\n    {\"NVIDIA GeForce RTX 4090\", &ADA, 512, 2520},\n    {\"NVIDIA GeForce RTX 4080 SUPER\", &ADA, 320, 2550},\n    {\"NVIDIA GeForce RTX 4080\", &ADA, 304, 2505},\n    {\"NVIDIA GeForce RTX 4070 Ti SUPER\", &ADA, 264, 2610},\n    {\"NVIDIA GeForce RTX 4070 Ti\", &ADA, 240, 2610},\n    {\"NVIDIA GeForce RTX 4070 SUPER\", &ADA, 224, 2475},\n    {\"NVIDIA GeForce RTX 4070\", &ADA, 184, 2475},\n    {\"NVIDIA GeForce RTX 4070\", &ADA, 184, 2475},\n    {\"NVIDIA GeForce RTX 4060 Ti\", &ADA, 136, 2535},\n    {\"NVIDIA GeForce RTX 4060\", &ADA, 96, 2460},\n    {\"NVIDIA H100 
PCIe\", &HOPPER, 456, 1620},\n    {\"NVIDIA H100 80GB HBM3\", &HOPPER, 528, 1830}, // HBM3 = SXM5\n};\n\nfloat get_flops_promised(const char* device, int precision_mode) {\n    /*\n    This function is used to estimate the Model Flops Utilization (MFU)\n    basically we have to figure out how many flops the GPU can do per second.\n    Note that this is not a simple endeavor and may well go wrong! The details are tricky.\n    The returned value is in units of 1e12.\n\n    For the non-top models, actual performance numbers aren't that easy to find, e.g.,\n    here https://www.techpowerup.com/gpu-specs/rtx-a4000.c3756, where \"Theoretical Performance\"\n    seems to be without tensor cores.\n\n    So, instead we use that all these cards just use the same types of tensor cores in different\n    numbers and at different frequencies. Then we just need to look up these two easily accessible\n    numbers for all the other GPUs.\n    linear scaling seems to work: comparing spec sheet and calculation:\n    4080: 304TCs, 2505 MHz; 97.5TFlops = 165.2/512*304 /2520 * 2505\n\n    Original numbers for the top GPUs are from:\n    https://resources.nvidia.com/en-us-tensor-core\n    https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf\n    */\n\n   // validate the precision mode as one of the three possible values\n    if (!(precision_mode == MFUH_PRECISION_FP32 || precision_mode == MFUH_PRECISION_FP16 || precision_mode == MFUH_PRECISION_BF16)) {\n        fprintf(stderr, \"Invalid precision mode: %d\\n\", precision_mode);\n        return -1.0f;\n    }\n\n    // do a linear search until you find our GPU, then calculate the flops promised\n    int num_gpu_entries = sizeof(gpu_db) / sizeof(gpu_db[0]);\n    for (int i = 0; i < num_gpu_entries; i++) {\n        if (strcmp(gpu_db[i].name, device) == 0) {\n            const PerfData* perf_data = gpu_db[i].perf_data;\n\n            // look up the default flops value for the given precision mode\n            
float value = -1.0f;\n            if (precision_mode == MFUH_PRECISION_BF16) { value = perf_data->BF_16_32; }\n            if (precision_mode == MFUH_PRECISION_FP32) { value = perf_data->TF_32; }\n            if (precision_mode == MFUH_PRECISION_FP16) { value = perf_data->FP_16_32; }\n\n            // we'd get here if we're e.g. trying to use BF16 on Volta GPU or something...\n            if (value < 0.0f) {\n                fprintf(stderr, \"No data for GPU %s and precision mode %d\\n\", device, precision_mode);\n                return -1.0f;\n            }\n\n            // adjust flops based on the specific core count and clock frequency of this GPU\n            float new_cores = gpu_db[i].new_cores;\n            float new_mhz = gpu_db[i].new_mhz;\n            float adjusted = value * (new_cores / perf_data->CORES) * (new_mhz / perf_data->CLOCK);\n            return adjusted;\n        }\n    }\n\n    return -1.0f; // ¯\\_(ツ)_/¯\n}\n\nstruct GPUUtilInfo {\n    unsigned int clock;\n    unsigned int max_clock;\n    unsigned int power;\n    unsigned int power_limit;\n    unsigned int fan;\n    unsigned int temperature;\n    unsigned int temp_slowdown;\n\n    float gpu_utilization;\n    float mem_utilization;\n    const char* throttle_reason;\n};\n\n// lazily initialize nvml and generate a handle to the GPU\n#if USE_NVML\nnvmlDevice_t nvml_get_device() {\n    static bool needs_init = true;\n    static nvmlDevice_t device;\n    if(needs_init) {\n        needs_init = false;\n        nvmlCheck(nvmlInit());\n        nvmlCheck(nvmlDeviceGetHandleByIndex_v2(0, &device));\n    }\n    return device;\n}\n\n// convert throttle reason bitfield into a text reason.\n// this is a lossy conversion; we just want to give some idea of what is happening\nconst char* get_throttle_reason(unsigned long long bits) {\n    if(bits & (nvmlClocksThrottleReasonSwPowerCap | nvmlClocksThrottleReasonHwPowerBrakeSlowdown)) {\n        return \"power cap\";\n    } else if (bits & 
(nvmlClocksThrottleReasonSwThermalSlowdown | nvmlClocksThrottleReasonHwThermalSlowdown)) {\n        return \"thermal cap\";\n    } else if (bits & (nvmlClocksThrottleReasonAll)) {\n        return \"other cap\";\n    } else {\n        return \"no cap\";\n    }\n}\n\n// gather data for a GPUUtilInfo object\nGPUUtilInfo get_gpu_utilization_info() {\n    GPUUtilInfo info;\n    nvmlDevice_t device = nvml_get_device();\n    // query different infos directly\n    nvmlCheck(nvmlDeviceGetClockInfo(device, NVML_CLOCK_SM, &info.clock));\n    nvmlCheck(nvmlDeviceGetMaxClockInfo(device, NVML_CLOCK_SM, &info.max_clock));\n    nvmlCheck(nvmlDeviceGetPowerManagementLimit(device, &info.power_limit));\n    nvmlCheck(nvmlDeviceGetPowerUsage(device, &info.power));\n    nvmlCheck(nvmlDeviceGetTemperature(device, NVML_TEMPERATURE_GPU, &info.temperature));\n    nvmlCheck(nvmlDeviceGetTemperatureThreshold(device, NVML_TEMPERATURE_THRESHOLD_SLOWDOWN, &info.temp_slowdown));\n    unsigned long long throttle;\n    nvmlCheck(nvmlDeviceGetCurrentClocksThrottleReasons(device, &throttle));\n    info.throttle_reason = get_throttle_reason(throttle);\n    nvmlCheck(nvmlDeviceGetFanSpeed(device, &info.fan));\n\n    // for \"utilization\", we look at recorded samples. In principle, we could query the driver for how many samples\n    // to request, but then we'd need to dynamically allocate sufficient space. 
Let's just hard-code a limit of 128,\n    // and have no memory management required\n    constexpr const int BUFFER_LIMIT = 128;\n    nvmlSample_t buffer[BUFFER_LIMIT];\n    nvmlValueType_t v_type;\n    unsigned int sample_count = BUFFER_LIMIT;\n    nvmlCheck(nvmlDeviceGetSamples(device, NVML_GPU_UTILIZATION_SAMPLES, 0, &v_type, &sample_count, buffer));\n    float gpu_utilization = 0.f;\n    for(unsigned i = 0; i < sample_count; ++i) {\n        gpu_utilization += (float)buffer[i].sampleValue.uiVal;\n    }\n    gpu_utilization /= (float)sample_count;\n\n    // sample count may have been modified by the query above; reset back to buffer size\n    sample_count = BUFFER_LIMIT;\n    nvmlCheck(nvmlDeviceGetSamples(device, NVML_MEMORY_UTILIZATION_SAMPLES, 0, &v_type, &sample_count, buffer));\n    float mem_utilization = 0.f;\n    for(unsigned i = 0; i < sample_count; ++i) {\n        mem_utilization += (float)buffer[i].sampleValue.uiVal;\n    }\n    mem_utilization /= (float)sample_count;\n\n    info.gpu_utilization = gpu_utilization;\n    info.mem_utilization = mem_utilization;\n    return info;\n}\n#else\nGPUUtilInfo get_gpu_utilization_info() {\n    fprintf(stderr, \"Error: Compiled without nvml support. Cannot perform additional GPU state tracking.\");\n    exit(EXIT_FAILURE);\n}\n#endif\n#endif // MFU_H\n"
  },
  {
    "path": "llmc/outlier_detector.h",
    "content": "/*\nSimple OutlierDetector that we can use to monitor the loss and grad norm\nInternally, it keeps track of a window of measurements and each time we\nadd a measurement, it returns the z-score of the new value with respect to\nthe window of measurements. This can be used to detect outliers in the data.\n\nWe use double so that the detector doesn't drift too much, because we\nupdate the mean and variance with += on each step for efficiency. We could\nreconsider this choice in the future, as the compute cost here is minimal.\n*/\n\n#include <stdio.h>\n#include <math.h>\n\n// use compile-time constant for window size to avoid dynamic memory allocations\n#define OUTLIER_DETECTOR_WINDOW_SIZE 128\n\ntypedef struct {\n    double buffer[OUTLIER_DETECTOR_WINDOW_SIZE];\n    int count;\n    int index;\n    double sum;\n    double sum_sq;\n} OutlierDetector;\n\nvoid init_detector(OutlierDetector *detector) {\n    for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE; i++) {\n        detector->buffer[i] = 0.0;\n    }\n    detector->count = 0;\n    detector->index = 0;\n    detector->sum = 0.0;\n    detector->sum_sq = 0.0;\n}\n\ndouble update_detector(OutlierDetector *detector, double new_value) {\n\n    if (detector->count < OUTLIER_DETECTOR_WINDOW_SIZE) {\n        // here we are still building up a window of observations\n        detector->buffer[detector->count] = new_value;\n        detector->sum += new_value;\n        detector->sum_sq += new_value * new_value;\n        detector->count++;\n        return nan(\"\"); // not enough data yet\n\n    } else {\n        // we've filled the window, so now we can start detecting outliers\n\n        // pop the oldest value from the window\n        double old_value = detector->buffer[detector->index];\n        detector->sum -= old_value;\n        detector->sum_sq -= old_value * old_value;\n        // push the new value into the window\n        detector->buffer[detector->index] = new_value;\n        detector->sum += 
new_value;\n        detector->sum_sq += new_value * new_value;\n        // move the index to the next position\n        detector->index = (detector->index + 1) % OUTLIER_DETECTOR_WINDOW_SIZE;\n        // calculate the z-score of the new value\n        double mean = detector->sum / OUTLIER_DETECTOR_WINDOW_SIZE;\n        double variance = (detector->sum_sq / OUTLIER_DETECTOR_WINDOW_SIZE) - (mean * mean);\n        double std_dev = sqrt(variance);\n        if (std_dev == 0.0) {\n            return 0.0;\n        }\n        double z = (new_value - mean) / std_dev;\n\n        return z;\n    }\n}\n"
  },
  {
    "path": "llmc/rand.h",
    "content": "/*\nMersenne Twisters implementation, numerically identical to torch.\n\nExample usage:\n\n    mt19937_state state;\n    manual_seed(&state, 137);\n    printf(\"%u\\n\", randint32(&state));\n    printf(\"%u\\n\", randint32(&state));\n    printf(\"%u\\n\", randint32(&state));\n    printf(\"%u\\n\", randint32(&state));\n    printf(\"%u\\n\", randint32(&state));\n\n    float t8[8];\n    normal_(t8, 8, 0, 1, &state);\n    for (int i = 0; i < 8; i++) {\n        printf(\"%f\\n\", t8[i]);\n    }\n    printf(\"%u\\n\", randint32(&state));\n\n    float t16[16];\n    normal_(t16, 16, 0, 1, &state);\n    for (int i = 0; i < 16; i++) {\n        printf(\"%f\\n\", t16[i]);\n    }\n    printf(\"%u\\n\", randint32(&state));\n\nPyTorch reference (producing identical results):\n\n    import torch\n    torch.manual_seed(137)\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n    t = torch.zeros(8);\n    t.normal_()\n    for i in range(len(t)) :\n        print(t[i].item())\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n    t = torch.zeros(16);\n    t.normal_()\n    for i in range(len(t)) :\n        print(t[i].item())\n    print(torch.randint(0, 0xFFFFFFFF, [1]).item())\n\nBoth output:\n\n    4053805790\n    2173880614\n    380293709\n    1237255315\n    2986595568\n    0.7947664260864258\n    1.4369317293167114\n    - 0.2292192131280899\n    0.47556325793266296\n    - 0.6334410905838013\n    - 0.5791953802108765\n    - 0.0925704762339592\n    - 0.8659197092056274\n    2186503452\n    - 1.2813878059387207\n    - 2.646395683288574\n    - 0.06569503247737885\n    0.2180829495191574\n    - 0.46536165475845337\n    - 0.33108410239219666\n    2.5485482215881348\n    0.10425379872322083\n    0.8460659980773926\n    0.9462448358535767\n    - 
0.2913765013217926\n    0.34313806891441345\n    - 1.1186704635620117\n    - 0.18305328488349915\n    - 2.3153159618377686\n    0.3961987793445587\n    2756748748\n*/\n\n#ifndef RAND_H\n#define RAND_H\n\n#include <math.h>\n\n#define MERSENNE_STATE_M 397u\n#define MERSENNE_STATE_N 624u\n\n#define LMASK 0x7ffffffful\n#define UMASK 0x80000000ul\n\n// Copyright(c) Makoto Matsumoto and Takuji Nishimura\n\n// This implementation follows PyTorch so that we are numerically identical when running verification tests.\n\ntypedef struct {\n    unsigned long long seed_;\n    int left_;\n    unsigned int next_;\n    unsigned int state_[MERSENNE_STATE_N];\n    unsigned int MATRIX_A[2];\n} mt19937_state;\n\nvoid manual_seed(mt19937_state* state, unsigned int seed) {\n    state->MATRIX_A[0] = 0x0u;\n    state->MATRIX_A[1] = 0x9908b0df;\n    state->state_[0] = seed & 0xffffffff;\n    for (unsigned int j = 1; j < MERSENNE_STATE_N; j++) {\n        state->state_[j] = 1812433253 * (state->state_[j - 1] ^ (state->state_[j - 1] >> 30)) + j;\n        state->state_[j] &= 0xffffffff;\n    }\n    state->left_ = 1;\n    state->next_ = 0;\n}\n\nvoid next_state(mt19937_state* state) {\n    state->left_ = MERSENNE_STATE_N;\n    state->next_ = 0;\n    unsigned int y, j;\n    for (j = 0; j < MERSENNE_STATE_N - MERSENNE_STATE_M; j++) {\n        y = (state->state_[j] & UMASK) | (state->state_[j + 1] & LMASK);\n        state->state_[j] = state->state_[j + MERSENNE_STATE_M] ^ (y >> 1) ^ state->MATRIX_A[y & 0x1];\n    }\n    for (; j < MERSENNE_STATE_N - 1; j++) {\n        y = (state->state_[j] & UMASK) | (state->state_[j + 1] & LMASK);\n        state->state_[j] = state->state_[j + (MERSENNE_STATE_M - MERSENNE_STATE_N)] ^ (y >> 1) ^ state->MATRIX_A[y & 0x1];\n    }\n    y = (state->state_[MERSENNE_STATE_N - 1] & UMASK) | (state->state_[0] & LMASK);\n    state->state_[MERSENNE_STATE_N - 1] = state->state_[MERSENNE_STATE_M - 1] ^ (y >> 1) ^ state->MATRIX_A[y & 0x1];\n}\n\nunsigned int 
randint32(mt19937_state* state) {\n    if (!state) return 0;\n    if (state->MATRIX_A[0] != 0 || state->MATRIX_A[1] != 0x9908b0df) manual_seed(state, 5489); // auto-initialize\n    if (--state->left_ <= 0) {\n        next_state(state);\n    }\n    unsigned int y = state->state_[state->next_++];\n    y ^= y >> 11;\n    y ^= (y << 7) & 0x9d2c5680;\n    y ^= (y << 15) & 0xefc60000;\n    y ^= y >> 18;\n    return y;\n}\n\ninline unsigned long long randint64(mt19937_state* state) {\n    return (((unsigned long long)(randint32(state)) << 32) | randint32(state));\n}\n\ninline float randfloat32(mt19937_state* state) {\n    return (randint32(state) & ((1ull << 24) - 1)) * (1.0f / (1ull << 24));\n}\n\ninline double randfloat64(mt19937_state* state) {\n    return (randint64(state) & ((1ull << 53) - 1)) * (1.0 / (1ull << 53));\n}\n\nvoid uniform_(float* data, unsigned int numel, float from, float to, mt19937_state* state) {\n    for (unsigned int t = 0; t < numel; t++) {\n        data[t] = randfloat32(state) * (to - from) + from;\n    }\n}\n\n// Box-Muller transform: maps uniform random numbers to Gaussian distributed numbers\n// https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform\nvoid normal_fill_16(float* data, float mean, float std) {\n    #define EPSILONE 1e-12f\n    for (unsigned int t = 0; t < 8; t++) {\n        float u1 = 1 - data[t];\n        float u2 = data[t + 8];\n        float radius = sqrtf(-2 * logf(u1 + EPSILONE));\n        float theta = (float) (2.0 * M_PI * u2);\n        data[t] = (radius * cosf(theta) * std + mean);\n        data[t + 8] = (radius * sinf(theta) * std + mean);\n    }\n}\n\nvoid normal_fill(float* data, unsigned int numel, float mean, float std, mt19937_state* state) {\n    for (unsigned int t = 0; t < numel; t++) {\n        data[t] = randfloat32(state);\n    }\n    for (unsigned int i = 0; i < numel - 15; i += 16) {\n        normal_fill_16(data + i, mean, std);\n    }\n    if (numel % 16 != 0) {\n        // recompute the last 16 
values\n        data = data + numel - 16;\n        for (unsigned int i = 0; i < 16; i++) {\n            data[i] = randfloat32(state);\n        }\n        normal_fill_16(data, mean, std);\n    }\n}\n\nvoid normal_(float* data, unsigned int numel, float mean, float std, mt19937_state* state) {\n    #define EPSILONE 1e-12f\n    if (numel >= 16) {\n        normal_fill(data, numel, mean, std, state);\n    }\n    else {\n        double next_double_normal_sample = 0.0; // make compiler warning happy, won't be used\n        int has_next_double_normal_sample = 0;\n        for (unsigned int  t = 0; t < numel; t++) {\n            if (has_next_double_normal_sample) {\n                data[t] = (float)(next_double_normal_sample * std + mean);\n                has_next_double_normal_sample = 0;\n                continue;\n            }\n            // for numel < 16 we draw a double (float64)\n            float u1 = (float) randfloat64(state);\n            float u2 = (float) randfloat64(state);\n            float radius = sqrtf(-2 * logf(1 - u2 + EPSILONE));\n            float theta = (float) (2.0 * M_PI * u1);\n            next_double_normal_sample = radius * sinf(theta);\n            has_next_double_normal_sample = 1;\n            data[t] = (radius * cosf(theta) * std + mean);\n        }\n    }\n}\n\nvoid init_identity_permutation(int *data, int numel) {\n    for (int i = 0; i < numel; i++) {\n        data[i] = i;\n    }\n}\n\nvoid random_permutation(int* data, int numel, mt19937_state* state) {\n    for (int i = numel - 1; i > 0; i--) {\n        // pick an index j in [0, i] with equal probability\n        int j = randint32(state) % (i + 1);\n        // swap i <-> j\n        int tmp = data[i];\n        data[i] = data[j];\n        data[j] = tmp;\n    }\n}\n\n#endif"
  },
  {
    "path": "llmc/sampler.h",
    "content": "/*\nImplements a simple Sampler, used during model inference to sample tokens.\n*/\n#ifndef SAMPLER_H\n#define SAMPLER_H\n\n#include <math.h>\n\n// Simple xorshift RNG\nunsigned int random_u32(unsigned long long *state) {\n    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A\n    *state ^= *state >> 12;\n    *state ^= *state << 25;\n    *state ^= *state >> 27;\n    return (*state * 0x2545F4914F6CDD1Dull) >> 32;\n}\n\nfloat random_f32(unsigned long long *state) { // random float32 in [0,1)\n    return (random_u32(state) >> 8) / 16777216.0f;\n}\n\nint sample_softmax(const float* logits, int n, float coin) {\n    // sample index from logits (converted to probabilities using softmax)\n    // coin is a random number in [0, 1), usually from random_f32()\n    double norm = 0;\n    for (int i = 0; i < n; i++) {\n        norm += expf(logits[i]);\n    }\n    // instead of dividing all exp(logits), we can just multiply coin.\n    coin *= norm;\n    float cdf = 0.0f;\n    for (int i = 0; i < n; i++) {\n        cdf += expf(logits[i]);\n        if (coin < cdf) {\n            return i;\n        }\n    }\n    return n - 1; // in case of rounding errors\n}\n\n#endif"
  },
  {
    "path": "llmc/schedulers.h",
    "content": "/*\nImplements various learning rate schedulers.\n*/\n#ifndef SCHEDULERS_H\n#define SCHEDULERS_H\n\n#include <assert.h>\n#include <math.h>\n#include <string.h>\n\ntypedef struct {\n    const char* type;\n    float learning_rate;\n    int warmup_iterations;\n    int train_num_batches;\n    float final_learning_rate_frac;\n} LearningRateScheduler;\n\nvoid lr_scheduler_init(LearningRateScheduler *scheduler, const char* scheduler_type, float learning_rate, int warmup_iterations, int train_num_batches, float final_learning_rate_frac) {\n    scheduler->type = scheduler_type;\n    scheduler->learning_rate = learning_rate;\n    scheduler->warmup_iterations = warmup_iterations;\n    scheduler->train_num_batches = train_num_batches;\n    scheduler->final_learning_rate_frac = final_learning_rate_frac;\n}\n\n// cosine: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac\nfloat get_learning_rate_cosine(LearningRateScheduler *scheduler, int step) {\n    float lr = scheduler->learning_rate;\n    if (step < scheduler->warmup_iterations) {\n        lr = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;\n    } else {\n        float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);\n        assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);\n        float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0\n        assert(0.0f <= coeff && coeff <= 1.0f);\n        float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;\n        lr = min_lr + coeff * (scheduler->learning_rate - min_lr);\n    }\n    return lr;\n}\n\n// linear: warmup linearly to max LR, then decay linearly to LR * final_learning_rate_frac\nfloat get_learning_rate_linear(LearningRateScheduler *scheduler, int step) {\n    float lr = scheduler->learning_rate;\n    if (step < scheduler->warmup_iterations) {\n        lr = 
scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;\n    } else {\n        float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);\n        assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);\n        float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;\n        lr = scheduler->learning_rate - decay_ratio * (scheduler->learning_rate - min_lr);\n    }\n    return lr;\n}\n\n// constant\nfloat get_learning_rate_constant(LearningRateScheduler *scheduler, int step) {\n    return scheduler->learning_rate;\n}\n\n// wsd schedule: warmup linearly, keep constant, last 20% decay using 1 - sqrt decay to final_frac (should be 0.0)\n// https://arxiv.org/abs/2405.18392\nfloat get_learning_rate_wsd(LearningRateScheduler *scheduler, int step) {\n    int decay_point = (int)(0.8f * scheduler->train_num_batches);\n    float max_lr = scheduler->learning_rate;\n    float lr = max_lr;\n    if (step < scheduler->warmup_iterations) {\n        float decay_ratio = ((float)(step + 1)) / scheduler->warmup_iterations;\n        lr = max_lr * decay_ratio;\n    } else if (step < decay_point) {\n        // noop, keep lr constant\n    } else {\n        float decay_ratio = ((float)(step - decay_point)) / (scheduler->train_num_batches - decay_point);\n        assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);\n        float min_lr = max_lr * scheduler->final_learning_rate_frac;\n        return min_lr + (1.0f - sqrtf(decay_ratio)) * (max_lr - min_lr);\n    }\n    return lr;\n}\n\n// return the learning rate at a given step\nfloat get_learning_rate(LearningRateScheduler *scheduler, int step) {\n    float step_learning_rate;\n    if (strcmp(scheduler->type, \"cosine\") == 0) {\n        step_learning_rate = get_learning_rate_cosine(scheduler, step);\n    } else if (strcmp(scheduler->type, \"linear\") == 0) {\n        step_learning_rate = get_learning_rate_linear(scheduler, 
step);\n    } else if (strcmp(scheduler->type, \"constant\") == 0) {\n        step_learning_rate = get_learning_rate_constant(scheduler, step);\n    } else if (strcmp(scheduler->type, \"wsd\") == 0) {\n        step_learning_rate = get_learning_rate_wsd(scheduler, step);\n    } else {\n        fprintf(stderr, \"Unknown learning rate scheduler type: %s\\n\", scheduler->type);\n        exit(EXIT_FAILURE);\n    }\n    return step_learning_rate;\n}\n\n#endif // SCHEDULERS_H\n"
  },
  {
    "path": "llmc/tokenizer.h",
    "content": "/*\nDefines the GPT-2 Tokenizer.\nOnly supports decoding, i.e.: tokens (integers) -> strings\nThis is all we need for unconditional generation.\nIf we wanted to later prompt the model, we'd have to add encoding.\nWhich could be tricky in C because of the regex involved, to look into later.\n*/\n\n#include <stdint.h>\n#include <ctype.h>\n#include <assert.h>\n// our own utilities\n// defines fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck\n#include \"utils.h\"\n\n// ----------------------------------------------------------------------------\n\ntypedef struct {\n    uint32_t vocab_size;\n    char **token_table;\n    int init_ok;\n    int eot_token; // <|endoftext|> token id\n} Tokenizer;\n\nvoid safe_printf(const char *piece) {\n    // the tokens are raw bytes, and we only want to print the printable ones\n    // many bytes can be various control codes, backspace, etc.\n    if (piece == NULL) { return; }\n    if (piece[0] == '\\0') { return; }\n    // handle individual byte tokens\n    // every token is asserted to be at least one byte so doing piece[1] is ok\n    if (piece[1] == '\\0') {\n        unsigned char byte_val = piece[0];\n        if (!(isprint(byte_val) || isspace(byte_val))) {\n            return; // weird byte, don't print it\n        }\n    }\n    printf(\"%s\", piece);\n}\n\nvoid tokenizer_init(Tokenizer *tokenizer, const char *filename) {\n    FILE *file = fopen(filename, \"rb\");\n    if (file == NULL) {\n        // try to be more helpful as we just added this feature, erase later\n        printf(\"---\\n\");\n        printf(\"WARNING: Failed to open the tokenizer file %s\\n\", filename);\n        printf(\"The Tokenizer is a new feature added April 14 2024.\\n\");\n        printf(\"Re-run `python train_gpt2.py` to write it\\n\");\n        printf(\"---\\n\");\n        tokenizer->init_ok = 0;\n        return;\n    }\n    // read in the header\n    uint32_t header[256];\n    freadCheck(header, sizeof(uint32_t), 256, file);\n    assert(header[0] == 20240328);\n    int version = header[1];\n    tokenizer->vocab_size = header[2];\n    if (version == 1) {\n        // version 1 didn't include the EOT token id\n        // so we assume it is 50256, the EOT in GPT-2\n        assert(tokenizer->vocab_size == 50257); // let's be defensive here\n        tokenizer->eot_token = 50256;\n    } else if (version == 2) {\n        tokenizer->eot_token = header[3];\n    } else {\n        fprintf(stderr, \"Tokenizer model file %s has bad version: %d\\n\", filename, version);\n        exit(EXIT_FAILURE);\n    }\n    // read in all the tokens\n    unsigned char length;\n    tokenizer->token_table = (char **)mallocCheck(tokenizer->vocab_size * sizeof(char *));\n    for (uint32_t i = 0; i < tokenizer->vocab_size; i++) {\n        freadCheck(&length, sizeof(unsigned char), 1, file);\n        assert(length > 0); // every token should be at least one character\n        char *token_bytes = (char *)mallocCheck(length + 1);\n        freadCheck(token_bytes, sizeof(char), length, file);\n        token_bytes[length] = '\\0';  // Add null terminator for printing\n        tokenizer->token_table[i] = token_bytes;\n    }\n    // cleanups\n    fcloseCheck(file);\n    tokenizer->init_ok = 1;\n}\n\nconst char *tokenizer_decode(Tokenizer *tokenizer, uint32_t token_id) {\n    if (tokenizer->init_ok == 0) {\n        return NULL;\n    }\n    if (token_id < tokenizer->vocab_size) {\n        return tokenizer->token_table[token_id];\n    } else {\n        printf(\"invalid token id %u!\\n\", token_id);\n        return NULL;\n    }\n}\n\nvoid tokenizer_free(Tokenizer *tokenizer) {\n    if (tokenizer->init_ok) {\n        for (uint32_t i = 0; i < tokenizer->vocab_size; i++) {\n            free(tokenizer->token_table[i]);\n        }\n        free(tokenizer->token_table);\n    }\n}\n"
  },
  {
    "path": "llmc/utils.h",
    "content": "/*\n This file contains utilities shared between the different training scripts.\n In particular, we define a series of macros xxxCheck that call the corresponding\n C standard library function and check its return code. If an error was reported,\n the program prints some debug information and exits.\n*/\n#ifndef UTILS_H\n#define UTILS_H\n\n#include <unistd.h>\n#include <string.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <sys/stat.h>\n// implementation of dirent for Windows is in dev/unistd.h\n#ifndef _WIN32\n#include <dirent.h>\n#include <arpa/inet.h>\n#endif\n\n// ----------------------------------------------------------------------------\n// fread convenience utils, with nice handling of error checking using macros\n// simple replace fopen, fread, fclose, fseek\n// with fopenCheck, freadCheck, fcloseCheck, fseekCheck\n\nextern inline FILE *fopen_check(const char *path, const char *mode, const char *file, int line) {\n    FILE *fp = fopen(path, mode);\n    if (fp == NULL) {\n        fprintf(stderr, \"Error: Failed to open file '%s' at %s:%d\\n\", path, file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Path: %s\\n\", path);\n        fprintf(stderr, \"  Mode: %s\\n\", mode);\n        fprintf(stderr, \"---> HINT 1: dataset files/code have moved to dev/data recently (May 20, 2024). You may have to mv them from the legacy data/ dir to dev/data/(dataset), or re-run the data preprocessing script. 
Refer back to the main README\\n\");\n        fprintf(stderr, \"---> HINT 2: possibly try to re-run `python train_gpt2.py`\\n\");\n        exit(EXIT_FAILURE);\n    }\n    return fp;\n}\n\n#define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__)\n\nextern inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {\n    size_t result = fread(ptr, size, nmemb, stream);\n    if (result != nmemb) {\n        if (feof(stream)) {\n            fprintf(stderr, \"Error: Unexpected end of file at %s:%d\\n\", file, line);\n        } else if (ferror(stream)) {\n            fprintf(stderr, \"Error: File read error at %s:%d\\n\", file, line);\n        } else {\n            fprintf(stderr, \"Error: Partial read at %s:%d. Expected %zu elements, read %zu\\n\",\n                    file, line, nmemb, result);\n        }\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Expected elements: %zu\\n\", nmemb);\n        fprintf(stderr, \"  Read elements: %zu\\n\", result);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__)\n\nextern inline void fclose_check(FILE *fp, const char *file, int line) {\n    if (fclose(fp) != 0) {\n        fprintf(stderr, \"Error: Failed to close file at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__)\n\nextern inline void sclose_check(int sockfd, const char *file, int line) {\n    if (close(sockfd) != 0) {\n        fprintf(stderr, \"Error: Failed to close socket at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        
fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define scloseCheck(sockfd) sclose_check(sockfd, __FILE__, __LINE__)\n\n#ifdef _WIN32\nextern inline void closesocket_check(int sockfd, const char *file, int line) {\n    if (closesocket(sockfd) != 0) {\n        fprintf(stderr, \"Error: Failed to close socket at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define closesocketCheck(sockfd) closesocket_check(sockfd, __FILE__, __LINE__)\n#endif\n\nextern inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {\n    if (fseek(fp, off, whence) != 0) {\n        fprintf(stderr, \"Error: Failed to seek in file at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  Offset: %ld\\n\", off);\n        fprintf(stderr, \"  Whence: %d\\n\", whence);\n        fprintf(stderr, \"  File:   %s\\n\", file);\n        fprintf(stderr, \"  Line:   %d\\n\", line);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define fseekCheck(fp, off, whence) fseek_check(fp, off, whence, __FILE__, __LINE__)\n\nextern inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {\n    size_t result = fwrite(ptr, size, nmemb, stream);\n    if (result != nmemb) {\n        if (feof(stream)) {\n            fprintf(stderr, \"Error: Unexpected end of file at %s:%d\\n\", file, line);\n        } else if (ferror(stream)) {\n            fprintf(stderr, \"Error: File write error at %s:%d\\n\", file, line);\n        } else {\n            fprintf(stderr, \"Error: Partial write at %s:%d. 
Expected %zu elements, wrote %zu\\n\",\n                    file, line, nmemb, result);\n        }\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Expected elements: %zu\\n\", nmemb);\n        fprintf(stderr, \"  Written elements: %zu\\n\", result);\n        exit(EXIT_FAILURE);\n    }\n}\n\n#define fwriteCheck(ptr, size, nmemb, stream) fwrite_check(ptr, size, nmemb, stream, __FILE__, __LINE__)\n\n// ----------------------------------------------------------------------------\n// malloc error-handling wrapper util\n\nextern inline void *malloc_check(size_t size, const char *file, int line) {\n    void *ptr = malloc(size);\n    if (ptr == NULL) {\n        fprintf(stderr, \"Error: Memory allocation failed at %s:%d\\n\", file, line);\n        fprintf(stderr, \"Error details:\\n\");\n        fprintf(stderr, \"  File: %s\\n\", file);\n        fprintf(stderr, \"  Line: %d\\n\", line);\n        fprintf(stderr, \"  Size: %zu bytes\\n\", size);\n        exit(EXIT_FAILURE);\n    }\n    return ptr;\n}\n\n#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__)\n\n\n// ----------------------------------------------------------------------------\n// check that all tokens are within range\nextern inline void token_check(const int* tokens, int token_count, int vocab_size, const char *file, int line) {\n    for(int i = 0; i < token_count; i++) {\n        if(!(0 <= tokens[i] && tokens[i] < vocab_size)) {\n            fprintf(stderr, \"Error: Token out of vocabulary at %s:%d\\n\", file, line);\n            fprintf(stderr, \"Error details:\\n\");\n            fprintf(stderr, \"  File: %s\\n\", file);\n            fprintf(stderr, \"  Line: %d\\n\", line);\n            fprintf(stderr, \"  Token: %d\\n\", tokens[i]);\n            fprintf(stderr, \"  Position: %d\\n\", i);\n            fprintf(stderr, \"  Vocab: %d\\n\", vocab_size);\n            
exit(EXIT_FAILURE);\n        }\n    }\n}\n#define tokenCheck(tokens, count, vocab) token_check(tokens, count, vocab, __FILE__, __LINE__)\n\n// ----------------------------------------------------------------------------\n// I/O ops\n\nextern inline void create_dir_if_not_exists(const char *dir) {\n    if (dir == NULL) { return; }\n    struct stat st = {0};\n    if (stat(dir, &st) == -1) {\n        if (mkdir(dir, 0700) == -1) {\n            printf(\"ERROR: could not create directory: %s\\n\", dir);\n            exit(EXIT_FAILURE);\n        }\n        printf(\"created directory: %s\\n\", dir);\n    }\n}\n\nextern inline int find_max_step(const char* output_log_dir) {\n    // find the DONE file in the log dir with highest step count\n    if (output_log_dir == NULL) { return -1; }\n    DIR* dir;\n    struct dirent* entry;\n    int max_step = -1;\n    dir = opendir(output_log_dir);\n    if (dir == NULL) { return -1; }\n    while ((entry = readdir(dir)) != NULL) {\n        if (strncmp(entry->d_name, \"DONE_\", 5) == 0) {\n            int step = atoi(entry->d_name + 5);\n            if (step > max_step) {\n                max_step = step;\n            }\n        }\n    }\n    closedir(dir);\n    return max_step;\n}\n\nextern inline int ends_with_bin(const char* str) {\n    // checks if str ends with \".bin\". could be generalized in the future.\n    if (str == NULL) { return 0; }\n    size_t len = strlen(str);\n    const char* suffix = \".bin\";\n    size_t suffix_len = strlen(suffix);\n    if (len < suffix_len) { return 0; }\n    int suffix_matches = strncmp(str + len - suffix_len, suffix, suffix_len) == 0;\n    return suffix_matches;\n}\n\n#endif"
  },
  {
    "path": "llmc/zero.cuh",
    "content": "/*\nUtilities for ZeRO sharding\n*/\n\n#ifndef LLMC_ZERO_CUH\n#define LLMC_ZERO_CUH\n\n#include <cuda_runtime_api.h>\n#include <stdint.h>\n#include <stdlib.h>\n#include <stdio.h>\n#include <stddef.h>\n\n#ifdef MULTI_GPU\n#include <nccl.h>\n#ifdef USE_MPI\n#include <mpi.h>\n#endif\n#endif\n\n// defines: fcloseCheck, fwriteCheck, scloseCheck, sclosesocketCheck\n#include \"utils.h\"\n\n// ----------------------------------------------------------------------------\n// Multi-GPU related\n#ifdef MULTI_GPU\n\n#if defined(ENABLE_FP32)\nconst ncclDataType_t ncclFloatX = ncclFloat;\n#elif defined(ENABLE_FP16)\nconst ncclDataType_t ncclFloatX = ncclHalf;\n#else // Default to bfloat16\nconst ncclDataType_t ncclFloatX = ncclBfloat16;\n#endif\n\nvoid nccl_check(ncclResult_t status, const char *file, int line) {\n    if (status != ncclSuccess) {\n        printf(\"[NCCL ERROR] at file %s:%d:\\n%s\\n\", file, line, ncclGetErrorString(status));\n        exit(EXIT_FAILURE);\n    }\n}\n#define ncclCheck(err) (nccl_check(err, __FILE__, __LINE__))\n\n#ifdef USE_MPI\nvoid mpi_check(int status, const char *file, int line) {\n    if (status != MPI_SUCCESS) {\n        char mpi_error[4096];\n        int mpi_error_len = 0;\n        assert(MPI_Error_string(status, &mpi_error[0], &mpi_error_len) == MPI_SUCCESS);\n        printf(\"[MPI ERROR] at file %s:%d:\\n%.*s\\n\", file, line, mpi_error_len, mpi_error);\n        exit(EXIT_FAILURE);\n    }\n}\n#define mpiCheck(err) (mpi_check(err, __FILE__, __LINE__))\n#endif\n\n#endif // MULTI_GPU\n\n// ----------------------------------------------------------------------------\n// Parameters specific to training on multiple GPUs.\ntypedef struct {\n    int process_rank;      // Rank of this process among all processes. 0 if no multi-GPU.\n    int num_processes;     // Total number of processes. 1 if no multi-GPU.\n    int local_device_idx;  // This process GPU index on current machine. 
0 if no multi-GPU.\n\n    // Zero Redundancy Optimizer stage - https://fairscale.readthedocs.io/en/stable/deep_dive/oss_sdp_fsdp.html\n    // 0-Disabled\n    // 1-Optimizer State Sharding (OSS)\n    // 2-Optimizer + Gradient State Sharding (SDP)\n    // 3-Optimizer + Gradient + Horizontal Model Sharding (FSDP)\n    int zero_stage;\n    size_t shard_num_parameters;\n#ifdef MULTI_GPU\n    ncclComm_t nccl_comm;       // NCCL communication primitive, used for collective multi-GPU work.\n    cudaStream_t nccl_stream;   // CUDA Stream to perform NCCL operations.\n    cudaEvent_t compute_nccl_sync; // Event used to synchronize NCCL with the compute\n    float* unified_buffer;\n#endif\n} MultiGpuConfig;\n\n// one global variable to hold the multi-GPU configuration for this process\n// inline, so we can include this header multiple times without getting multiple definitions\ninline MultiGpuConfig multi_gpu_config;\n\n#ifdef MULTI_GPU\n\n#ifdef _WIN32\nvoid send_nccl_id_to_clients_windows(ncclUniqueId *nccl_id, SOCKET client_sockets[], int num_clients) {\n    for (int i = 0; i < num_clients; ++i) {\n        if (send(client_sockets[i], (const char *)nccl_id, sizeof(*nccl_id), 0) == SOCKET_ERROR) {\n            printf(\"Failed to send nccl_id\");\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n        closesocketCheck(client_sockets[i]);\n    }\n}\n#else\nvoid send_nccl_id_to_clients(ncclUniqueId *nccl_id, int client_sockets[], int num_clients) {\n    for (int i = 0; i < num_clients; ++i) {\n        if (send(client_sockets[i], nccl_id, sizeof(*nccl_id), 0) == -1) {\n            printf(\"Failed to send nccl_id\");\n            exit(EXIT_FAILURE);\n        }\n        scloseCheck(client_sockets[i]);\n    }\n}\n#endif\n\n#ifdef _WIN32\n// Same as get_nccl_id_via_tcp but for Windows\nncclUniqueId get_nccl_id_via_tcp_windows(MultiGpuConfig* result, const char* server_ip) {\n    ncclUniqueId nccl_id;\n\n    int SERVER_PORT = 12345;  // hardcoded an arbitrary 
port number between 1024 and 49151 (registered ports)\n    WSADATA wsaData;\n    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {\n        printf(\"WSAStartup failed\");\n        exit(EXIT_FAILURE);\n    }\n\n    if (result->process_rank == 0) {\n        ncclCheck(ncclGetUniqueId(&nccl_id));\n\n        int MAX_CLIENTS = result->num_processes - 1;\n        SOCKET client_sockets[MAX_CLIENTS];\n        int num_clients = 0;\n        SOCKET server_socket, new_socket;\n        struct sockaddr_in address;\n        int addrlen = sizeof(address);\n\n        // Step 1) create a server TCP socket\n        if ((server_socket = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {\n            printf(\"Socket failed\");\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 2) set the server address and port\n        address.sin_family = AF_INET;  // IPv4\n        address.sin_addr.s_addr = inet_addr(server_ip);\n        address.sin_port = htons(SERVER_PORT);\n\n        // Step 3) bind the socket to the address and port\n        if (bind(server_socket, (struct sockaddr *)&address, sizeof(address)) == SOCKET_ERROR) {\n            printf(\"Bind failed\");\n            closesocketCheck(server_socket);\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 4) MAX_CLIENTS specifies the maximum number of clients that can be queued for this server\n        if (listen(server_socket, MAX_CLIENTS) == SOCKET_ERROR) {\n            printf(\"Listen failed\");\n            closesocketCheck(server_socket);\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 5) accept connections from clients\n        printf(\"Waiting for clients to connect...\\n\");\n        while (num_clients < MAX_CLIENTS) {\n            if ((new_socket = accept(server_socket, (struct sockaddr *)&address, &addrlen)) == INVALID_SOCKET) {\n                printf(\"Accept failed\");\n                
closesocketCheck(server_socket);\n                WSACleanup();\n                exit(EXIT_FAILURE);\n            }\n            client_sockets[num_clients++] = new_socket;\n            printf(\"Client %d connected\\n\", num_clients);\n        }\n\n        // Step 6) send the NCCL ID to all clients\n        send_nccl_id_to_clients_windows(&nccl_id, client_sockets, num_clients);\n        printf(\"NCCL ID sent to all clients\\n\");\n\n        closesocketCheck(server_socket);\n    } else {\n        int num_connection_attempts = 5;\n        int time_to_sleep = 2;\n        SOCKET client_socket;\n        struct sockaddr_in serv_addr;\n\n        // Step 1) create a client TCP socket\n        if ((client_socket = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {\n            printf(\"Socket creation error\");\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 2) set the server address and port\n        serv_addr.sin_family = AF_INET;\n        serv_addr.sin_port = htons(SERVER_PORT);\n        if (inet_pton(AF_INET, server_ip, &serv_addr.sin_addr) <= 0) {\n            printf(\"Invalid address or address not supported\");\n            closesocketCheck(client_socket);\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 3) Try to connect to the server - retry up to `num_connection_attempts` times if the connection fails\n        while (connect(client_socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) == SOCKET_ERROR) {\n            printf(\"%d Connection failed, retrying in %d seconds\\n\", result->process_rank, time_to_sleep);\n            if (--num_connection_attempts == 0) {\n                printf(\"Failed to connect to the server\\n\");\n                closesocketCheck(client_socket);\n                WSACleanup();\n                exit(EXIT_FAILURE);\n            }\n            Sleep(time_to_sleep * 1000);\n        }\n\n        // Step 4) receive the NCCL ID from the server\n        
if (recv(client_socket, (char *)&nccl_id, sizeof(nccl_id), 0) <= 0) {\n            printf(\"Failed to receive nccl_id\");\n            closesocketCheck(client_socket);\n            WSACleanup();\n            exit(EXIT_FAILURE);\n        }\n\n        printf(\"Received NCCL ID\\n\");\n        closesocketCheck(client_socket);\n    }\n\n    WSACleanup();\n    return nccl_id;\n}\n#else\nncclUniqueId get_nccl_id_via_tcp(MultiGpuConfig* result, const char* server_ip) {\n    ncclUniqueId nccl_id;\n\n    int SERVER_PORT = 12345;  // hardcoded an arbitrary port number between 1024 and 49151 (registered ports)\n    if (result->process_rank == 0) {\n        ncclCheck(ncclGetUniqueId(&nccl_id));\n\n        int MAX_CLIENTS = result->num_processes - 1;\n        int client_sockets[MAX_CLIENTS];\n        int num_clients = 0;\n        int server_socket, new_socket;\n        struct sockaddr_in address;\n        int addrlen = sizeof(address);\n        int opt = 1;\n\n        // Step 1) create a server TCP socket\n        if ((server_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) {\n            printf(\"Socket failed\");\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 2) set socket options\n        // SOL_SOCKET - means that option is configured at socket level\n        // SO_REUSEADDR - allows to bind to an address which is in a TIME_WAIT state (already used by another socket) - useful when restarting the server\n        // SO_REUSEPORT - allows to bind to the same port multiple times\n        if (setsockopt(server_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)) < 0) {\n            printf(\"Setsockopt failed\");\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 3) set the server address and port\n        address.sin_family = AF_INET;  // IPv4\n        address.sin_addr.s_addr = inet_addr(server_ip); // alternatively use INADDR_ANY to bind to all interfaces, currently we only allow ethernet\n        address.sin_port = 
htons(SERVER_PORT);\n\n        // Step 4) bind the socket to the address and port\n        if (bind(server_socket, (struct sockaddr *)&address, sizeof(address)) < 0) {\n            printf(\"Bind failed\");\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 5) MAX_CLIENTS specifies the maximum number of clients that can be queued for this server\n        if (listen(server_socket, MAX_CLIENTS) < 0) {\n            printf(\"Listen failed\");\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 6) accept connections from clients\n        printf(\"Waiting for clients to connect...\\n\");\n        while (num_clients < MAX_CLIENTS) {\n            if ((new_socket = accept(server_socket, (struct sockaddr *)&address, (socklen_t*)&addrlen)) < 0) {\n                printf(\"Accept failed\");\n                exit(EXIT_FAILURE);\n            }\n            client_sockets[num_clients++] = new_socket;\n            printf(\"Client %d connected\\n\", num_clients);\n        }\n\n        // Step 7) send the NCCL ID to all clients\n        send_nccl_id_to_clients(&nccl_id, client_sockets, num_clients);\n        printf(\"NCCL ID sent to all clients\\n\");\n\n        scloseCheck(server_socket);\n    } else {\n        int num_connection_attempts = 5;\n        int time_to_sleep = 2;\n        int client_socket;\n        struct sockaddr_in serv_addr;\n\n        // Step 1) create a client TCP socket\n        if ((client_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) {\n            printf(\"Socket creation error\");\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 2) set the server address and port\n        serv_addr.sin_family = AF_INET;\n        serv_addr.sin_port = htons(SERVER_PORT);\n        if (inet_pton(AF_INET, server_ip, &serv_addr.sin_addr) <= 0) {\n            printf(\"Invalid address or address not supported\");\n            exit(EXIT_FAILURE);\n        }\n\n        // Step 3) Try to connect to the server - retry up to `num_connection_attempts` 
times if the connection fails\n        while (connect(client_socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {\n            printf(\"%d Connection failed, retrying in %d seconds\\n\", result->process_rank, time_to_sleep);\n            if (--num_connection_attempts == 0) {\n                printf(\"Failed to connect to the server\\n\");\n                exit(EXIT_FAILURE);\n            }\n            sleep(time_to_sleep);\n        }\n\n        // Step 4) receive the NCCL ID from the server\n        if (recv(client_socket, &nccl_id, sizeof(nccl_id), 0) <= 0) {\n            printf(\"Failed to receive nccl_id\");\n            exit(EXIT_FAILURE);\n        }\n\n        printf(\"Received NCCL ID\\n\");\n        scloseCheck(client_socket);\n    }\n\n    return nccl_id;\n}\n#endif\n\nncclUniqueId get_nccl_id_via_fs(MultiGpuConfig* result, char* fs_path) {\n    // Works assuming that the filesystem is shared among all processes\n    ncclUniqueId nccl_id;\n    FILE* idFile;\n    static char filename[1024];\n    snprintf(filename, sizeof(filename), \"%s/ncclUniqueId.sync\", fs_path);\n\n    if (result->process_rank != 0) {  // client processse should wait for the server to write to the file\n        // This is a naive and not 100% robust way to synchronize the processes but it should work almost always\n        sleep(2);\n    }\n\n    if (result->process_rank == 0) {\n        ncclCheck(ncclGetUniqueId(&nccl_id));\n        idFile = fopen(filename, \"wb\");\n        assert(idFile != NULL);\n        fwriteCheck(&nccl_id, sizeof(nccl_id), 1, idFile);\n        fcloseCheck(idFile);\n    } else {\n        // Other ranks wait until the file is available and read the unique ID\n        do {\n            sleep(1);  // 1 second\n            idFile = fopen(filename, \"rb\");\n            if (idFile != NULL) break;\n        } while (idFile == NULL);\n        freadCheck(&nccl_id, sizeof(nccl_id), 1, idFile);\n        fcloseCheck(idFile);\n    }\n\n    return 
nccl_id;\n}\n\n#ifdef USE_MPI\n// Determine which GPU this process should use.\n// Processes on the same machines use different GPU indicies. Processes on other machines don't.\n// Copied from NCCL examples: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/examples.html#example-2-one-device-per-process-or-thread\nint multi_gpu_get_local_device_idx(int process_rank, int num_processes) {\n    char hostname[1024];\n    hostname[1023] = '\\0';\n    // All processes on the same machine will share the same hostname.\n    gethostname(hostname, 1023);\n    for (int i=0; i < 1024; i++) {\n        if (hostname[i] == '.') {\n            hostname[i] = '\\0';\n            break;\n        }\n    }\n    uint64_t hostname_hash = 5381u;\n    for (int c = 0; hostname[c] != '\\0'; c++){ hostname_hash = ((hostname_hash << 5u) + hostname_hash) ^ hostname[c]; }\n\n    // Distribute all hostname hashes to all processes.\n    uint64_t* all_hostsname_hashes = (uint64_t*)malloc(num_processes * sizeof(uint64_t));\n    all_hostsname_hashes[process_rank] = hostname_hash;\n    mpiCheck(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_hostsname_hashes, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));\n\n    // Identify which GPU we need to use.\n    int local_device_idx = 0;\n    for (int current_process = 0; current_process < num_processes; ++current_process) {\n        if (current_process == process_rank) {\n        // Found my gpu, local_device_idx now has my target GPU index.\n        break;\n        }\n        if (all_hostsname_hashes[current_process] == all_hostsname_hashes[process_rank]) {\n        // This process ID runs on the same machine, but it's not me, skip this GPU\n        local_device_idx++;\n        }\n    }\n\n    free(all_hostsname_hashes);\n    return local_device_idx;\n}\n#endif\n\n#endif\n\nMultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gpus_per_node, char* server_ip, char* fs_path, char* init_method) {\n#ifdef MULTI_GPU\n    
MultiGpuConfig result;\n    ncclUniqueId nccl_id;\n    // Get nccl_id using MPI, TCP, or FS (file system synchronization) methods\n    // On newer slurm versions (slurm-wlm package) PMIx is disabled so we can not use MPI for NCCL init in multi node setup\n    if (strcmp(init_method, \"mpi\") == 0) {\n        #ifdef USE_MPI\n        mpiCheck(MPI_Init(NULL, NULL));\n        mpiCheck(MPI_Comm_rank(MPI_COMM_WORLD, &result.process_rank));\n        mpiCheck(MPI_Comm_size(MPI_COMM_WORLD, &result.num_processes));\n        result.local_device_idx = multi_gpu_get_local_device_idx(result.process_rank, result.num_processes);\n        if (result.process_rank == 0) {\n            ncclCheck(ncclGetUniqueId(&nccl_id));\n        }\n        mpiCheck(MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD));\n        #else\n        printf(\"MPI support is disabled. Please enable MPI support to use MPI-based NCCL-init method.\\n\");\n        exit(EXIT_FAILURE);\n        #endif\n    } else {\n        result.process_rank = process_rank;\n        result.num_processes = num_processes;\n        result.local_device_idx = process_rank % gpus_per_node;\n        if (strcmp(init_method, \"tcp\") == 0) {\n            #ifdef _WIN32\n            nccl_id = get_nccl_id_via_tcp_windows(&result, server_ip);\n            #else\n            nccl_id = get_nccl_id_via_tcp(&result, server_ip);\n            #endif\n        } else if (strcmp(init_method, \"fs\") == 0) {\n            nccl_id = get_nccl_id_via_fs(&result, fs_path);\n        } else {\n            printf(\"Invalid NCCL-init method\\n\");\n            exit(EXIT_FAILURE);\n        }\n    }\n    cudaCheck(cudaSetDevice(result.local_device_idx));\n    ncclCheck(ncclCommInitRank(&result.nccl_comm, result.num_processes, nccl_id, result.process_rank));\n    cudaCheck(cudaStreamCreate(&result.nccl_stream));\n    // event without timing for maximum performance\n    cudaCheck(cudaEventCreate(&result.compute_nccl_sync, cudaEventDisableTiming));\n  
  nvtxNameCudaStreamA(result.nccl_stream, \"nccl stream\");\n    nvtxNameCudaEventA(result.compute_nccl_sync, \"nccl compute sync\");\n    cudaCheck(cudaMallocManaged(&result.unified_buffer, sizeof(float)));\n    return result;\n#else\n    printf(\"Multi-GPU support is disabled. Using a single GPU.\\n\");\n    cudaCheck(cudaSetDevice(0));\n    MultiGpuConfig result;\n    result.process_rank = 0;\n    result.num_processes = 1;\n    result.local_device_idx = 0;\n    return result;\n#endif\n}\n\nvoid multi_gpu_config_free(MultiGpuConfig* config) {\n#ifdef MULTI_GPU\n    ncclCheck(ncclCommDestroy(config->nccl_comm));\n    cudaCheck(cudaStreamDestroy(config->nccl_stream));\n    cudaCheck(cudaEventDestroy(config->compute_nccl_sync));\n    cudaCheck(cudaFree(config->unified_buffer));\n    #ifdef USE_MPI\n    mpiCheck(MPI_Finalize());\n    #endif\n#endif\n}\n\nvoid multi_gpu_barrier(const MultiGpuConfig* config) {\n#ifdef MULTI_GPU\n    if (config->num_processes > 1) {\n        ncclCheck(ncclAllReduce(config->unified_buffer, config->unified_buffer, sizeof(float), ncclFloat, ncclSum, config->nccl_comm, config->nccl_stream));\n    }\n    cudaCheck(cudaDeviceSynchronize());\n#endif\n}\n\n// Offset and size of a tensor shard\ntypedef struct {\n    ptrdiff_t offset;\n    size_t size;\n} ShardInfo;\n\n// Get info about sharding for a tensor of elements many numbers\nShardInfo multi_gpu_get_shard_offset(size_t elements, const MultiGpuConfig* config, int shard_at_stage) {\n    const int nproc = config->num_processes;\n    if(config->zero_stage >= shard_at_stage) {\n        if (elements % nproc != 0) {\n            fprintf(stderr, \"Number of elements %zu must be a multiple of the number of processes %d\\n\", elements, nproc);\n            exit(EXIT_FAILURE);\n        }\n        return {(ptrdiff_t) (config->process_rank * (elements / nproc)), elements / nproc};\n    } else {\n        return {0, elements};\n    }\n}\n\n// Block NCCL stream until computations on compute_stream are 
done, then aggregate multiple pointers in an NCCL group.\n// This can work either as an all-reduce (i.e., no ZeRo), or a reduce-scatter (ZeRO 1).\n// The awkward `(&pointers)[N]` syntax ensures we are capturing the parameters as sized arrays, so that it becomes impossible\n// to call this function if pointers and pointers_sizes do not match.\ntemplate<int N>\nvoid multi_gpu_async_reduce_gradient(\n        floatX* const (&pointers)[N], const size_t (&pointers_sizes)[N],\n        MultiGpuConfig* config, cudaStream_t compute_stream) {\n    if (config->num_processes == 1) {\n        return; // no multi-GPU, just exit.\n    }\n\n#ifdef MULTI_GPU\n    NVTX_RANGE_FN();\n    // mark an event on the compute stream, and immediately wait on this in the nccl stream\n    // this means that the nccl stream won't start executing before all compute kernels that\n    // have been submitted before this point have finished.\n    // by using an event instead of cudaSyncStream, we avoid having to synchronize the host, and\n    // can enqueue new work to the GPU right away.\n    cudaCheck(cudaEventRecord(config->compute_nccl_sync, compute_stream));\n    cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));\n    ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.\n    for (int i = 0; i < N; ++i) {\n        if(config->zero_stage == 0) {\n            ncclCheck(ncclAllReduce(\n                    pointers[i], pointers[i],\n                    pointers_sizes[i],\n                    ncclFloatX, ncclAvg,\n                    config->nccl_comm, config->nccl_stream\n            ));\n        } else if(config->zero_stage == 1) {\n            assert(pointers_sizes[i] % config->num_processes == 0);\n            size_t shard_size = pointers_sizes[i] / config->num_processes;\n            ptrdiff_t shard_offset = (ptrdiff_t)shard_size * config->process_rank;\n            ncclCheck(ncclReduceScatter(\n                    pointers[i], 
pointers[i] + shard_offset,\n                    shard_size,\n                    ncclFloatX, ncclAvg,\n                    config->nccl_comm, config->nccl_stream\n            ));\n        }\n    }\n    ncclCheck(ncclGroupEnd());\n#endif\n}\n\n// convenience macro that only prints if the rank of process is zero\n#define printf0(...) if (::multi_gpu_config.process_rank == 0) { printf(__VA_ARGS__); }\n\nvoid set_zero_configs(MultiGpuConfig* config, int zero_stage, size_t total_parameters) {\n    config->zero_stage = 0;\n    config->shard_num_parameters = total_parameters;\n    // Check the Zero Stage and define sharding parameters\n    if (zero_stage == 0) {\n        printf0(\"| Zero Optimization is disabled                                              |\\n\");\n    }\n    else if (zero_stage == 1) {\n        if (total_parameters % config->num_processes != 0) {\n            printf0(\"| Zero Optimization is disabled, Can't equally partition parameters          |\\n\");\n            config->zero_stage = 0;\n        }\n        else {\n            config->zero_stage = 1;\n            config->shard_num_parameters = total_parameters / config->num_processes;\n        }\n    }\n    else{\n        printf0(\"| Disabling Zero Optimization, Zero Stage2 and Stage3 are not yet supported  |\\n\");\n        config->zero_stage = 0;\n    }\n}\n\n// Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.\nfloat multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) {\n#ifdef MULTI_GPU\n    if (config->num_processes == 1) return value;\n\n    float* unified_buffer = config->unified_buffer;\n    *unified_buffer = value;\n    ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, config->nccl_comm, config->nccl_stream));\n    cudaCheck(cudaDeviceSynchronize());\n    return *unified_buffer;\n#else\n    return value;\n#endif\n}\n\n#endif\n\n"
  },
  {
    "path": "profile_gpt2.cu",
    "content": "/*\nThis code is a convenience tool for profiling the CUDA kernels in the training\nloop of train_gpt2.cu. Compile:\n\nmake profile_gpt2cu NO_MULTI_GPU=1\n\nAnd then e.g. use ncu from NVIDIA. The CLI docs for example:\nhttps://docs.nvidia.com/nsight-compute/NsightComputeCli/\n\nTLDR run like:\n\nsudo ncu --set full --import-source yes -o profile -f ./profile_gpt2cu\n\nThis:\n- `--set full` means we'll collect A LOT of metrics. take out for less\n- `--import-source yes` means we'll get the source code in the profile\n- `-o profile` writes the results into file profile.ncu-rep\n- `-f` forces overwrite of the profile.ncu-rep file\n- `./profile_gpt2cu` is the executable we want to profile\n\nThis writes results into profile.ncu-rep output file.\nYou can open this up in NVIDIA Nsight Compute UI.\nFor example, I have NVIDIA Nsight Compute installed on my Mac, and I rsync\nthe profile.ncu-rep from a cloud box to local to pretty view.\n*/\n\n#define TESTING\n#include \"train_gpt2.cu\"\n\nint main(int argc, char *argv[]) {\n    char nccl_init_method[256] = \"mpi\";  // \"tcp\" or \"fs\" or \"mpi\"\n    int num_processes = -1;  // doesn't matter when using MPI\n    int process_rank = -1;  // doesn't matter when using MPI\n    int gpus_per_node = -1;  // doesn't matter when using MPI\n    char server_ip[256] = \"\";  // doesn't matter when using MPI\n    char fs_path[256] = \"\";  // doesn't matter when using MPI\n    multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);\n    common_start(true, true);\n\n    // build the GPT-2 model from a checkpoint\n    GPT2 model;\n    gpt2_init_common(&model);\n    gpt2_build_from_checkpoint(&model, \"gpt2_124M_bf16.bin\");\n\n    int B = 24; // if program OOMs decrease this number, e.g. all the way down to 4 or etc\n    int T = 1024; // if even that OOMs move on to this one. 
keep them nice and powers of 2\n    printf(\"batch size: %d\\n\", B);\n    printf(\"sequence length: %d\\n\", T);\n\n    int* x = (int*)mallocCheck(B * T * sizeof(int));\n    int* y = (int*)mallocCheck(B * T * sizeof(int));\n    for(int  i = 0; i < B  * T; ++i) {\n        x[i] = i % model.config.vocab_size;\n        y[i] = i % model.config.vocab_size;\n    }\n\n    // override number of layers to 1 because all layers repeat the same kernels, only profile once\n    model.config.num_layers = 1;\n    set_zero_configs(&multi_gpu_config, 0, model.num_parameters);\n\n    gpt2_allocate_state(&model, B, T);\n    // do a training step\n    gpt2_forward(&model, x, B, T);\n    gpt2_backward_and_reduce(&model, x, y, 1, 0);\n    float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);\n    float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f;\n    gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, grad_scale, 1, &multi_gpu_config);\n    cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings\n\n    // free\n    gpt2_free(&model);\n    common_free(model);\n    return 0;\n}\n"
  },
  {
    "path": "profile_gpt2cu.py",
    "content": "# runs profiling with ncu, generates a `profile.ncu-rep` for viewing with NSight Compute, and prints out\n# basic kernel stats.\n# Note: If you run into errors because of missing access rights to performance counters, try\n# https://developer.nvidia.com/nvidia-development-tools-solutions-err_nvgpuctrperm-permission-issue-performance-counters#SolnAdminTag\n\nimport subprocess\nimport csv\nfrom collections import defaultdict\nimport shutil\n\n# find ncu: Is it on PATH?\nNCU = shutil.which(\"ncu\")\n# otherwise, guess a standard location\nif NCU is None:\n    NCU = \"/usr/local/cuda/bin/ncu\"\n\n# build the executable\nsubprocess.check_call([\"make\", \"profile_gpt2cu\", \"NO_MULTI_GPU=1\", \"USE_CUDNN=1\"])\n\n# try to see if profiling is allowed for non-root:\noptions = subprocess.check_output([\"modprobe\", \"-c\", \"nvidia\"], text=True)\ncan_profile = len([l for l in options.splitlines() if \"NVreg_RestrictProfilingToAdminUsers=0\" in l]) != 0\n\n# record metrics\n# --full and --import-source are entirely superfluous for this script, but you might want to\n# manually inspect `profile.ncu-rep`, so we keep it here\ncmd = [NCU, \"--set\", \"full\", \"--import-source\", \"yes\", \"-o\", \"profile\", \"-f\", \"./profile_gpt2cu\"]\n# do we need to run under sudo\nif not can_profile:\n    print(\"NVreg_RestrictProfilingToAdminUsers=1, running with sudo\")\n    cmd = [\"sudo\"] + cmd\nsubprocess.check_call(cmd)\n\n# generate csv\n# https://forums.developer.nvidia.com/t/converting-nsys-rep-file-into-a-csv-file-with-formatting-like-the-summary-page-in-ncu-gui/231717/3\nmetrics = [\n    \"gpu__time_duration.sum\",                   # total time\n    \"dram__bytes_read.sum\",                     # DRAM reads\n    \"dram__bytes_write.sum\",                    # DRAM writes\n    \"lts__t_sectors_srcunit_tex_op_read.sum\",   # L2 reads (sectors -- 32B)\n    \"lts__t_sectors_srcunit_tex_op_write.sum\",  # L2 writes (sectors -- 32B)\n    
\"sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active\", # % of peak tensor core utilization\n    \"smsp__inst_executed.sum\",                  # instructions\n]\ncmd = [NCU, \"-i\", \"profile.ncu-rep\", \"--csv\", \"--page\", \"raw\", \"--metrics\", \",\".join(metrics)]\nresult = subprocess.check_output(cmd, text=True).strip()\n\nreader = csv.reader(result.splitlines(keepends=True))\n\n# model config\nCLS_START = -1\nCLS_NUM = 6\nN_LAYERS = 12\n\nsummaries = defaultdict(lambda: 0.0)\ncounts = defaultdict(lambda: 0)\npasses = defaultdict(lambda: 0.0)\ntotal = defaultdict(lambda: 0.0)\nno_cutlass = 0.0\nCC = \"\"\nphase = \"fwd\"\n\nkernel_profile_data = list(enumerate(reader))\n\nfor rid, row in kernel_profile_data:\n    if rid <= 2:\n        continue\n    kernel = row[4]\n    kid = rid - 2\n    if \"fused_classifier\" in kernel:\n        #  classifier: layernorm -> matmul -> fused -> bw matmul (x2) -> bw layernorm\n        CLS_START = kid - 2\n\nassert CLS_START != -1\n\n# Check every kernel to find the maximum DRAM bandwidth and Tensor Core utilisation values\nmax_dram_bw = 0.0\nmax_tensor = 0.0\nfor rid, row in kernel_profile_data:\n    if rid <= 2:\n        continue\n    time = float(row[13])\n    read = float(row[11])\n    write = float(row[12])\n    tensor = float(row[16])\n    dram_bw = (read + write) / (time / 1000.0)\n    max_dram_bw = max(max_dram_bw, dram_bw)\n    max_tensor = max(max_tensor, tensor)\n\n# round the maximum tensor core utilisation to 50% or 100%\n# consumer GPUs can only achieve 50% of peak tensor throughput on this counter\n# and for GPUs without tensor cores, we set the value to 50% to avoid division by zero\nmax_tensor = (max_tensor > 50.0) and 100.0 or 50.0\n\nprint()\nprint(\"Kernel calls:\")\nfor rid, row in kernel_profile_data:\n    if rid == 0:\n        #  headings\n        print(  f\"id pass    {'name':<40} {'time':>8} {'RAM BW':>8} {'tensor':>8} {'RAM rd':>8} {'RAM wt':>8} {'L2 rd':>8} {'L2 wt':>8} 
{'inst':>8}\")\n        continue\n    if rid == 1:\n        # units\n        units = f\"           {'':<40} {'ms':>8} {'GB/s':>8} {'core %':>8} {'GiB':>8} {'GiB':>8} {'GiB':>8} {'GiB':>8} {'MInst':>8}\"\n        print(units)\n        print(\".\" * len(units))\n        continue\n    if rid == 2:\n        CC = row[10]\n\n    # actual data\n    kernel = row[4]\n    time = float(row[13])\n    read = float(row[11])\n    write = float(row[12])\n    l2_read = float(row[14])\n    l2_write = float(row[15])\n    tensor = float(row[16])\n    inst = float(row[17]) / 1e6\n    dram_bw = (read + write) / (time / 1000.0)\n\n    kid = rid - 2\n\n    multiplier = 1\n    if \"encoder\" in kernel:\n        pass_name = \"enc\"\n        if phase == \"bwd\":\n            phase = \"bwd-enc\"\n    elif CLS_START <= kid < CLS_START + CLS_NUM:\n        # the classifier part, counts only once\n        pass_name = \"cls\"\n        phase = \"bwd\"\n    elif \"adamw\" in kernel or \"global_norm\" in kernel or \"copy_and_cast\" in kernel:\n        # encoder layer or adam\n        pass_name = \"opt\"\n    # before the first optimizer run, we create weight copies.\n    # they aren't part of regular processing, so they get a multiplier\n    # of zero\n    elif phase == \"bwd-enc\":\n        pass_name = \"init\"\n        multiplier = 0\n    else:\n        pass_name = phase\n        multiplier = N_LAYERS\n        time *= N_LAYERS\n        read *= N_LAYERS\n        write *= N_LAYERS\n        l2_read *= N_LAYERS\n        l2_write *= N_LAYERS\n        inst *= N_LAYERS\n\n    # split at \"(\" -- argument list\n    fn_name = kernel.split(\"(\")[0]\n    # some names include the return value, others don't?\n    if \" \" in fn_name:\n        fn_name = fn_name.split(\" \")[1]\n    if \"<\" in fn_name:\n        fn_name = fn_name.split(\"<\")[0]\n\n    # group together matmul kernels\n    if \"cutlass\" in fn_name:\n        pass\n    elif fn_name.startswith(\"ampere_bf16\"):\n        fn_name = \"ampere_bf16\"\n  
  elif fn_name.startswith(\"cudnn_generated_fort_native_sdpa\"):\n        fn_name = \"cudnn_generated_fort_native_sdpa\"\n    else:\n        no_cutlass += time\n\n    # convert L2 to GiB\n    l2_read = l2_read * 32 / 1024 / 1024 / 1024\n    l2_write = l2_write * 32 / 1024 / 1024 / 1024\n\n    efficiency = max(dram_bw / max_dram_bw, tensor / max_tensor)\n    summaries[fn_name] += time\n    counts[fn_name] += multiplier\n    passes[pass_name] += time\n    if pass_name != \"init\":\n        total['time'] += time\n        total['read'] += read\n        total['write'] += write\n        total['l2_read'] += l2_read\n        total['l2_write'] += l2_write\n        total['inst'] += inst\n        total['tensor'] += tensor * time # % so multiplied by time\n        total['efficiency'] += efficiency * time\n\n    pass_info = f\"{pass_name}×{multiplier}\"\n    print(f\"{kid:02} {pass_info:7} {fn_name:<40} {time:8.2f} {dram_bw:8.1f} {tensor:8.1f} {read:8.2f} {write:8.2f} {l2_read:8.2f} {l2_write:8.2f} {inst:8.2f}\")\n\n\ntotal_time = total['time']\navg_dram_bw = (total['read'] + total['write']) / (total_time / 1000.0)\navg_tensor_util = total['tensor'] / total_time\nprint(\".\" * len(units))\nprint(f\"           {'Total':<40} {total['time']:8.2f} {avg_dram_bw:8.1f} {avg_tensor_util:8.1f} {total['read']:8.2f} {total['write']:8.2f} {total['l2_read']:8.2f} {total['l2_write']:8.2f} {total['inst']:8.2f}\")\n\nprint()\nprint(\"Kernel type summaries:\")\nprint(f\"  {'name':<40} {'time':>6} {'frac':>6}  {'count':>6}\")\nordered_time = sorted(summaries.items(), key=lambda x: x[1], reverse=True)\nfor entry, value in ordered_time:\n    # crop entry to be at most 40 characters\n    if len(entry) > 40:\n        entry_text = entry[:37] + \"...\"\n    else:\n        entry_text = entry\n    print(f\"  {entry_text:<40} {value:6.2f} {100*value / total_time:6.2f}% {counts[entry]:>6d}\")\n\n\nts = total_time / 1000\nsummary = f\"\"\"\nIn total, a training step takes {total_time:.1f}ms, distributed 
as:\n  {passes['enc']:.1f}ms ({100 * passes['enc'] / total_time:.1f}%) in the encoder,\n  {passes['fwd']:.1f}ms ({100 * passes['fwd'] / total_time:.1f}%) in forward blocks,\n  {passes['cls']:.1f}ms ({100 * passes['cls'] / total_time:.1f}%) in the classifier part,\n  {passes['bwd']:.1f}ms ({100 * passes['bwd'] / total_time:.1f}%) in backward blocks, and\n  {passes['opt']:.1f}ms ({100 * passes['opt'] / total_time:.1f}%) in the optimizer.\n\nWe read {total['read']:.1f}GiB ({total['read']/ts:.1f}GB/s) and write {total['write']:.1f}GiB ({total['write']/ts:.1f}GB/s) to DRAM,\nread {total['l2_read']:.1f}GiB ({total['l2_read']/ts:.1f}GB/s) and write {total['l2_write']:.1f}GiB ({total['l2_write']/ts:.1f}GB/s) to L2,\nand execute {total['inst'] / 1000:.1f} billion instructions ({total['inst'] / 1000 / ts:.1f} GInst/s).\n\nAssuming that every kernel should be either fully DRAM bandwidth or tensor core limited,\nwith a peak DRAM bandwidth of {max_dram_bw:.1f}GB/s and a peak tensor throughput of {max_tensor:.1f}%,\nour overall efficiency is {(total['efficiency'] * 100.0 / total_time):.1f}%.\n\"\"\"\nprint(summary)"
  },
  {
    "path": "requirements.txt",
    "content": "tqdm\nnumpy<2\ntorch\ntiktoken\ntransformers\ndatasets\nrequests\n"
  },
  {
    "path": "scripts/README.md",
    "content": "# scripts\n\nThese shell scripts hold the exact commands to llm.c that reproduce the GPT-2 and GPT-3 runs.\n\n### pytorch reference runs\n\nFor all pyrun scripts, current restrictions:\n\n- does not write checkpoint, only logs of the train/val losses\n- does not evaluate hellaswag accuracy\n- cannot \"resume training\" (i.e. the `-y 1` flag)\n\n### memory considerations\n\nIn any of these scripts, if you are running out of memory on your GPU you'll want to meddle with two flags: the recompute setting `-r` and the microbatch size `-b`. Recompute throws away some activations during the forward pass and then recomputes them during the backward pass. This reduces the amount of memory we need to store and cache during the forward pass, but then increases the amount of computation we need to do during the backward pass. The microbatch size controls the number of token streams that are processed in a single forward/backward pass in parallel. Decreasing this number means we need to store less memory per microbatch, but then we have to increase the number of loops in the gradient accumulation to meet the same desired total batch size.\n\nLong story short, try `-r 1` (recompute GeLU, trading off speed and memory) to conserve some memory. If that doesn't help, start dividing the micro batch size until things fit. For example if the deafult is `-b 64`, try `-b 32`, and then 16, 8, etc. until things fit. Once they do fit, experiment with dialing back the recompute flag `-r 0` to get some speed back. Alternatively to `-b`, if your application doesn't need a very long context length, you can dial back the number of max sequence length using `-t`. For example GPT-2 uses `-t 1024` and GPT-3 uses `-t 2048`. Your application may tolerate a lower context length.\n\n### multi-gpu considerations\n\nIt might be that you only have one GPU and not a whole box of them. Every script is fairly easy to change for just a single GPU. 
For llm.c, simply change line 1 to line 2 and leave everything else the same:\n\n```bash\nmpirun -np 8 ./train_gpt2cu \\\n./train_gpt2cu \\\n```\n\nFor PyTorch, the same thing:\n\n```bash\ntorchrun --standalone --nproc_per_node=8 train_gpt2.py \\\npython train_gpt2.py \\\n```\n\nBoth of these scripts automatically detect how many GPUs are available and adjust the gradient accumulation inner loop of the optimization accordingly, so the results come out the same, up to floating point error. Of course, you'll have to wait proportionally longer for the optimization to finish.\n\nTo run on multiple nodes of GPUs, have a look at this pending [PR](https://github.com/karpathy/llm.c/pull/426); alternatively, for llm.c try something like this:\n\n```bash\nmpirun -np 16 --host node1:8,node2:8 ./train_gpt2cu ...\n```\n\nFor PyTorch follow the torchrun docs.\n"
  },
  {
    "path": "scripts/multi_node/run_gpt2_124M_fs.sbatch",
    "content": "#!/bin/bash\n#SBATCH --job-name=llmc-multinode                                     # job name\n#SBATCH --output=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.log   # output file\n#SBATCH --error=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.err    # error file\n#SBATCH --partition=llmc                                              # Specify the GPU partition\n#SBATCH --ntasks=16                                                   # total number of processes to launch on all nodes\n#SBATCH --nodes=2                                                     # total number of nodes\n#SBATCH --ntasks-per-node=8                                           # assuming each node has 8 gpus\n#SBATCH --gres=gpu:8                                                  # request 8 gpus from each node\n\n# NOTE: change the above slurm arguments to match your system!\n# Run with `sbatch <path_to_this_script.sh>`\n\nmake train_gpt2cu USE_CUDNN=1 NO_USE_MPI=1\n\n# NOTE: change the following to match your system\nbinary_path=\"/home/ubuntu/llm.c/train_gpt2cu\"\nout_dir=\"/ephemeral/data/fineweb/log_gpt2_124M_multi\"\ntrain_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_train_*.bin'\nval_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_val_*.bin'\nsync_fs_path=$out_dir  # needs to be a shared filesystem path that all nodes can access\n\n# In case the file system is shared this is a no-op.\n# Otherwise, we need to copy the binary to all nodes.\ncurrent_user=$USER\nhosts=$(scontrol show hostnames $SLURM_JOB_NODELIST)  # get the hostnames of the allocated nodes\ncurrent_host=$(hostname)\nfor host in $hosts; do\n    if [ $host == $current_host ]; then\n        continue\n    fi\n    echo \"copying $binary_path to $current_user@$host\"\n    scp -r $binary_path $current_user@$host:$binary_path\ndone\n\n# Use this for NCCL debugging if you run into issues\n# export NCCL_DEBUG=INFO\n# export NCCL_DEBUG_SUBSYS=ALL\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n# Optimization 
flags\nexport NCCL_NET_GDR_LEVEL=2  # use GPUDirect RDMA - allows for direct memory access between GPUs across different nodes by bypassing the CPU\nexport NCCL_IB_DISABLE=0  # use InfiniBand if available\n\n# NOTE: change the following environment variables to match your system - or comment them out if you don't need them\nexport NCCL_SOCKET_IFNAME=ens17\nexport OMPI_MCA_btl_tcp_if_include=ens17\nexport NCCL_P2P_LEVEL=PXB\n\nif [ -z \"$SLURM_JOB_ID\" ]; then\n    echo \"Make sure you're running in a SLURM environment. Did you forget to run with sbatch? Aborting.\"\n    exit 1\nelse\n    DATESTRING=`date \"+%Y-%m-%dT%H:%M:%S\"`\n    echo \"Running in a SLURM environment (job ID: $SLURM_JOB_ID, user: $current_user)\"\n    echo \"Running on hosts: $(echo $(scontrol show hostname))\"\n    echo \"$DATESTRING\"\nfi\n\nsrun -l -u bash -c \"\n    $binary_path \\\n    -i '$train_data_path' \\\n    -j '$val_data_path' \\\n    -o $out_dir \\\n    -v 250 -s 20000 -g 144 \\\n    -h 1 \\\n    -b 64 -t 1024 \\\n    -d 2097152 \\\n    -r 0 \\\n    -z 1 \\\n    -c 0.1 \\\n    -l 0.0006 \\\n    -q 0.0 \\\n    -u 700 \\\n    -n 5000 \\\n    -y 1 \\\n    -e d12 \\\n    -pn \\$SLURM_NTASKS \\\n    -pr \\$SLURM_PROCID \\\n    -pg \\$SLURM_NTASKS_PER_NODE \\\n    -pf $sync_fs_path \\\n    -pi \"fs\" \\\n\"\n\necho \"$DATESTRING\""
  },
  {
    "path": "scripts/multi_node/run_gpt2_124M_mpi.sh",
    "content": "\nmake train_gpt2cu USE_CUDNN=1\n\n# NOTE: change the following to match your system\nbinary_path=\"/home/ubuntu/llm.c/train_gpt2cu\"\nout_dir=\"/ephemeral/data/fineweb/log_gpt2_124M_multi\"\ntrain_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_train_*.bin'\nval_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_val_*.bin'\n# You can find these names either in `/etc/hosts`` file or in the terminal (user@host:~$).\nhost1=\"h100-node-1-0\"  # master and worker node\nhost2=\"h100-node-1-1\"  # worker node\n\n# In case the file system is shared this is a no-op.\n# Otherwise, we need to copy the binary to all nodes.\nscp -r $binary_path $USER@$host2:$binary_path\n\n# Use this for NCCL debugging if you run into issues\n# export NCCL_DEBUG=INFO\n# export NCCL_DEBUG_SUBSYS=ALL\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n# Optimization flags\nexport NCCL_NET_GDR_LEVEL=2  # use GPUDirect RDMA - allows for direct memory access between GPUs across different nodes by bypassing the CPU\nexport NCCL_IB_DISABLE=0  # use InfiniBand if available\n\n# NOTE: change the following environment variables to match your system - or comment them out if you don't need them\nexport NCCL_SOCKET_IFNAME=ens17\nexport OMPI_MCA_btl_tcp_if_include=ens17\nexport NCCL_P2P_LEVEL=PXB\n\nmpirun -np 16 --host $host1:8,$host2:8 \\\n    $binary_path \\\n    -i \"$train_data_path\" \\\n    -j \"$val_data_path\" \\\n    -o $out_dir \\\n    -v 250 -s 20000 -g 144 \\\n    -h 1 \\\n    -b 64 -t 1024 \\\n    -d 2097152 \\\n    -r 0 \\\n    -z 1 \\\n    -c 0.1 \\\n    -l 0.0006 \\\n    -q 0.1 \\\n    -u 700 \\\n    -n 1000 \\\n    -y 0 \\\n    -e d12 \\\n    -pi \"mpi\" \\\n"
  },
  {
    "path": "scripts/multi_node/run_gpt2_124M_tcp.sbatch",
    "content": "#!/bin/bash\n#SBATCH --job-name=llmc-multinode                                     # job name\n#SBATCH --output=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.log   # output file\n#SBATCH --error=/home/ubuntu/llm.c/scripts/multi_node/%x_%j_%t.err    # error file\n#SBATCH --partition=llmc                                              # Specify the GPU partition\n#SBATCH --ntasks=16                                                   # total number of processes to launch on all nodes\n#SBATCH --nodes=2                                                     # total number of nodes\n#SBATCH --ntasks-per-node=8                                           # assuming each node has 8 gpus\n#SBATCH --gres=gpu:8                                                  # request 8 gpus from each node\n\n# NOTE: change the above slurm arguments to match your system!\n# Run with `sbatch <path_to_this_script.sh>`\n\nmake train_gpt2cu USE_CUDNN=1 NO_USE_MPI=1\n\n# NOTE: change the following to match your system\nbinary_path=\"/home/ubuntu/llm.c/train_gpt2cu\"\nout_dir=\"/ephemeral/data/fineweb/log_gpt2_124M_multi\"\ntrain_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_train_*.bin'\nval_data_path='/ephemeral/data/fineweb/bin_10B/fineweb_val_*.bin'\n# NOTE: change the server_ip to the IP address of the machine that is running process zero\nserver_ip=\"10.0.1.220\"\n\n# In case the file system is shared this is a no-op.\n# Otherwise, we need to copy the binary to all nodes.\ncurrent_user=$USER\nhosts=$(scontrol show hostnames $SLURM_JOB_NODELIST)  # get the hostnames of the allocated nodes\ncurrent_host=$(hostname)\nfor host in $hosts; do\n    if [ $host == $current_host ]; then\n        continue\n    fi\n    echo \"copying $binary_path to $current_user@$host\"\n    scp -r $binary_path $current_user@$host:$binary_path\ndone\n\n# Use this for NCCL debugging if you run into issues\n# export NCCL_DEBUG=INFO\n# export NCCL_DEBUG_SUBSYS=ALL\nexport 
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n\n# Optimization flags\nexport NCCL_NET_GDR_LEVEL=2  # use GPUDirect RDMA - allows for direct memory access between GPUs across different nodes by bypassing the CPU\nexport NCCL_IB_DISABLE=0  # use InfiniBand if available\n\n# NOTE: change the following environment variables to match your system - or comment them out if you don't need them\nexport NCCL_SOCKET_IFNAME=ens17\nexport OMPI_MCA_btl_tcp_if_include=ens17\nexport NCCL_P2P_LEVEL=PXB\n\nif [ -z \"$SLURM_JOB_ID\" ]; then\n    echo \"Make sure you're running in a SLURM environment. Did you forget to run with sbatch? Aborting.\"\n    exit 1\nelse\n    DATESTRING=`date \"+%Y-%m-%dT%H:%M:%S\"`\n    echo \"Running in a SLURM environment (job ID: $SLURM_JOB_ID, user: $current_user)\"\n    echo \"Running on hosts: $(echo $(scontrol show hostname))\"\n    echo \"$DATESTRING\"\nfi\n\nsrun -l -u bash -c \"\n    $binary_path \\\n    -i '$train_data_path' \\\n    -j '$val_data_path' \\\n    -o $out_dir \\\n    -v 250 -s 20000 -g 144 \\\n    -h 1 \\\n    -b 64 -t 1024 \\\n    -d 2097152 \\\n    -r 0 \\\n    -z 1 \\\n    -c 0.1 \\\n    -l 0.0006 \\\n    -q 0.0 \\\n    -u 700 \\\n    -n 5000 \\\n    -y 1 \\\n    -e d12 \\\n    -pn \\$SLURM_NTASKS \\\n    -pr \\$SLURM_PROCID \\\n    -pg \\$SLURM_NTASKS_PER_NODE \\\n    -ps $server_ip \\\n    -pi \"tcp\" \\\n\"\n\necho \"$DATESTRING\"\n"
  },
  {
    "path": "scripts/pyrun_gpt2_124M.sh",
    "content": "#!/bin/bash\n\n# the same as scripts/run_gpt2_124M.sh but with PyTorch\n\n# if you wish to train on just a single GPU, simply skip the torchrun part, i.e.\n# python train_gpt2.py ... (all the other arguments the same)\ntorchrun --standalone --nproc_per_node=8 train_gpt2.py \\\n    --input_bin \"dev/data/fineweb10B/fineweb_train_*.bin\" \\\n    --input_val_bin \"dev/data/fineweb10B/fineweb_val_*.bin\" \\\n    --val_loss_every 250 \\\n    --sample_every 0 \\\n    --output_dir pylog_gpt2_124M \\\n    --write_tensors 0 \\\n    --model d12 \\\n    --batch_size 32 \\\n    --sequence_length 1024 \\\n    --total_batch_size 524288 \\\n    --dtype bfloat16 \\\n    --compile 1 \\\n    --tensorcores 1 \\\n    --flash 1 \\\n    --num_iterations 18865 \\\n    --weight_decay 0.1 \\\n    --zero_stage 1 \\\n    --learning_rate 0.0006 \\\n    --warmup_iters 700 \\\n    --learning_rate_decay_frac 0.0 \\\n    --overfit_single_batch 0\n"
  },
  {
    "path": "scripts/run_gpt2_124M.sh",
    "content": "# GPT-2 (124M) repro on FineWeb\n# 124M parameter model on 10B tokens\n# => 6 * 124e6 * 10e9 = 7.44e18 ~= 7e18 capability model\n# 18,865 steps of 524,288 tokens/step\n# on 8X A100 80GB SXM ($14/hr) steps in ~300ms/iter\n# => training time 18,865 * 300ms = 94.3 min ~= $20\n\nmake train_gpt2cu USE_CUDNN=1\nout_dir=\"log_gpt2_124M\"\ndone_file=\"$out_dir/DONE_00018865\"\n\n# in case the training stalls or crashes, loop to resume (-y 1)\nwhile true; do\n\n    # exit condition is that optimization has finished\n    if [ -f \"$done_file\" ]; then\n        echo \"File $done_file exists. Exiting the loop.\"\n        break\n    fi\n\n    # run python dev/data/fineweb.py --version 10B to prepro data\n    # run python dev/data/hellaswag.py to prepro hellaswag eval\n    mpirun -np 8 ./train_gpt2cu \\\n                -i \"dev/data/fineweb10B/fineweb_train_*.bin\" \\\n                -j \"dev/data/fineweb10B/fineweb_val_*.bin\" \\\n                -o $out_dir \\\n                -v 250 -s 20000 -g 144 \\\n                -h 1 \\\n                -b 64 -t 1024 \\\n                -d 524288 \\\n                -r 0 \\\n                -z 1 \\\n                -c 0.1 \\\n                -l 0.0006 \\\n                -q 0.0 \\\n                -u 700 \\\n                -n 5000 \\\n                -y 1 \\\n                -e \"d12\"\n\n    sleep 1\ndone\n"
  },
  {
    "path": "scripts/run_gpt2_1558M.sh",
    "content": "# GPT-2 (1558M) repro on FineWeb-EDU\n# 1558M parameter model on 32B tokens\n# => 6 * 1558e6 * 32e9 = 6.966e20 ~= 3e20 capability model\n# 32,000 steps on ~1M tokens/step (1,048,576 to be precise)\n# on 8X H100 80GB SXM ($28/hr) steps in 2.80s/iter\n# => training time 32,000 steps * 2.7s => 24 hours ~= 1 day ~= $672\n\nmake train_gpt2cu USE_CUDNN=1\nout_dir=\"log_gpt2_1558M\"\ndone_file=\"$out_dir/DONE_00032000\"\n\n# in case the training stalls or crashes, loop to resume (-y 1)\nwhile true; do\n\n    # exit condition is that optimization has finished\n    if [ -f \"$done_file\" ]; then\n        echo \"File $done_file exists. Exiting the loop.\"\n        break\n    fi\n\n    mpirun -np 8 ./train_gpt2cu \\\n                -i \"dev/data/edu_fineweb100B/edu_fineweb_train_*.bin\" \\\n                -j \"dev/data/edu_fineweb100B/edu_fineweb_val_*.bin\" \\\n                -o $out_dir \\\n                -v 250 -s 300000 -g 384 \\\n                -h 1 \\\n                -b 16 -t 1024 \\\n                -d 1048576 \\\n                -r 0 \\\n                -z 1 \\\n                -c 0.1 \\\n                -k \"cosine\" \\\n                -l 0.0006 \\\n                -q 0.1 \\\n                -u 700 \\\n                -n 2000 \\\n                -x 32000 \\\n                -ge 1 \\\n                -y 1 \\\n                -e \"d48\"\n\n    sleep 1\ndone\n"
  },
  {
    "path": "scripts/run_gpt2_350M.sh",
    "content": "# GPT-2 (350M) repro on FineWeb\n# 350M parameter model on ~30B tokens\n# => 6 * 350e6 * 31.5e9 = 6.615e19 ~= 7e19 capability model (10X 124M)\n# 60K steps on 524,288 tokens/step\n# on 8X A100 80GB SXM ($14/hr) steps in ~820ms/iter\n# => training time 60,000 steps * 820ms = 13.7 hours ~= $200 (10X 124M)\n\nmake train_gpt2cu USE_CUDNN=1\nout_dir=\"log_gpt2_350M\"\ndone_file=\"$out_dir/DONE_00060000\"\n\n# in case the training stalls or crashes, loop to resume (-y 1)\nwhile true; do\n\n    # exit condition is that optimization has finished\n    if [ -f \"$done_file\" ]; then\n        echo \"File $done_file exists. Exiting the loop.\"\n        break\n    fi\n\n    # run python dev/data/fineweb.py --version 100B to prepro data\n    # run python dev/data/hellaswag.py to prepro hellaswag eval\n    mpirun -np 8 ./train_gpt2cu \\\n                -i \"dev/data/fineweb100B/fineweb_train_*.bin\" \\\n                -j \"dev/data/fineweb100B/fineweb_val_*.bin\" \\\n                -o $out_dir \\\n                -v 250 -s 100000 -g 144 \\\n                -h 1 \\\n                -b 64 -t 1024 \\\n                -d 524288 \\\n                -r 0 \\\n                -z 1 \\\n                -c 0.1 \\\n                -l 0.0003 \\\n                -q 0.0 \\\n                -u 700 \\\n                -n 2000 \\\n                -x 60000 \\\n                -y 1 \\\n                -e \"d24\"\n\n    sleep 1\ndone\n"
  },
  {
    "path": "scripts/run_gpt2_774M.sh",
    "content": "# GPT-2 (774M) repro on FineWeb\n# 774M parameter model on ~150B tokens\n# => 6 * 774e6 * 150e9 = 6.966e20 ~= 7e20 capability model (10X 350M)\n# => 286,102 steps on 524,288 tokens/step\n# on 8X A100 80GB SXM ($14/hr) steps in ~1.7s/iter\n# => training time 286,102 steps * 1.7s = 135 hours ~= 5.6 days ~= $2000 (10X 124M)\n\nmake train_gpt2cu USE_CUDNN=1\nout_dir=\"log_gpt2_774M\"\ndone_file=\"$out_dir/DONE_00286102\"\n\n# in case the training stalls or crashes, loop to resume (-y 1)\nwhile true; do\n\n    # exit condition is that optimization has finished\n    if [ -f \"$done_file\" ]; then\n        echo \"File $done_file exists. Exiting the loop.\"\n        break\n    fi\n\n    # run python dev/data/fineweb.py --version 100B to prepro data\n    # run python dev/data/hellaswag.py to prepro hellaswag eval\n    mpirun -np 8 ./train_gpt2cu \\\n                -i \"dev/data/fineweb100B/fineweb_train_*.bin\" \\\n                -j \"dev/data/fineweb100B/fineweb_val_*.bin\" \\\n                -o $out_dir \\\n                -v 250 -s 300000 -g 144 \\\n                -h 1 \\\n                -b 32 -t 1024 \\\n                -d 524288 \\\n                -r 0 \\\n                -z 1 \\\n                -c 0.1 \\\n                -l 0.00025 \\\n                -q 0.0 \\\n                -u 700 \\\n                -n 4000 \\\n                -x 286102 \\\n                -y 1 \\\n                -e \"d36\"\n\n    sleep 1\ndone\n"
  },
  {
    "path": "scripts/run_gpt3_125M.sh",
    "content": "# GPT-3 (125M) repro, but using FineWeb\n# 125M parameter model on 300B tokens\n# note context length: 1024 -> 2048 for GPT-3\n# => 6 * 125e6 * 300e9 = ~= 2.25e20 capability model\n# 572,204 steps of 524,288 tokens/step => 300B\n# on 8X A100 80GB SXM ($14/hr) steps in ~150ms/iter\n# => training time 572,204 * 150ms ~= 24 hours ~= $336\n\nmake train_gpt2cu USE_CUDNN=1\nout_dir=\"log_gpt3_125M\"\ndone_file=\"$out_dir/DONE_00572204\"\n\nwhile true; do\n\n    # exit condition is that optimization has finished\n    if [ -f \"$done_file\" ]; then\n        echo \"File $done_file exists. Exiting the loop.\"\n        break\n    fi\n\n    mpirun -np 8 ./train_gpt2cu \\\n                -i \"dev/data/fineweb100B/fineweb_train_*.bin\" \\\n                -j \"dev/data/fineweb100B/fineweb_val_*.bin\" \\\n                -o $out_dir \\\n                -v 250 -s 20000 -g 144 \\\n                -h 1 \\\n                -b 32 -t 2048 \\\n                -d 524288 \\\n                -r 0 \\\n                -z 1 \\\n                -c 0.1 \\\n                -l 0.0006 \\\n                -q 0.1 \\\n                -u 700 \\\n                -n 10000 \\\n                -nk 5 \\\n                -nm 50000 \\\n                -ge 1 \\\n                -sl 7.0 \\\n                -sg 7.0 \\\n                -y 1 \\\n                -x 572204 \\\n                -e \"gpt3:c768\"\n\n    sleep 1\ndone\n"
  },
  {
    "path": "test_gpt2.c",
    "content": "#define TESTING\n#include \"train_gpt2.c\"\n\n// poor man's tensor checker\nint check_tensor(float *a, float *b, int n, const char* label) {\n    int print_upto = 5;\n    int ok = 1;\n    float maxdiff = 0.0f;\n    float tol = 2e-2f;\n    printf(\"%s\\n\", label);\n    for (int i = 0; i < n; i++) {\n        // look at the diffence at position i of these two tensors\n        float diff = fabsf(a[i] - b[i]);\n\n        // keep track of the overall error\n        ok = ok && (diff <= tol);\n        if (diff > maxdiff) { maxdiff = diff; }\n\n        // for the first few elements of each tensor, pretty print\n        // the actual numbers, so we can do a visual, qualitative proof/assessment\n        if (i < print_upto) {\n            if (diff <= tol) {\n                if (i < print_upto) { printf(\"OK \"); }\n            } else {\n                if (i < print_upto) { printf(\"NOT OK \"); }\n            }\n            printf(\"%f %f\\n\", a[i], b[i]);\n        }\n    }\n    // print the final result for this tensor\n    if (ok) {\n        printf(\"TENSOR OK, maxdiff = %e\\n\", maxdiff);\n    } else {\n        printf(\"TENSOR NOT OK, maxdiff = %e\\n\", maxdiff);\n    }\n    return ok;\n}\n\nint main(int argc, char *argv[]) {\n\n    // build the GPT-2 model from a checkpoint\n    GPT2 model;\n    gpt2_build_from_checkpoint(&model, \"gpt2_124M.bin\");\n\n    int C = model.config.channels;\n    int V = model.config.vocab_size;\n    int Vp = model.config.padded_vocab_size;\n    int maxT = model.config.max_seq_len;\n    int L = model.config.num_layers;\n\n    // load additional information that we will use for debugging and error checking\n    FILE *state_file = fopen(\"gpt2_124M_debug_state.bin\", \"rb\");\n    if (state_file == NULL) { printf(\"Error opening state file\\n\"); return 1; }\n    int state_header[256];\n    freadCheck(state_header, sizeof(int), 256, state_file);\n    if (state_header[0] != 20240327) { printf(\"Bad magic state file\\n\"); return 
1; }\n    if (state_header[1] != 2) {\n        printf(\"Bad version in state file\\n\");\n        printf(\"---> HINT: try to re-run `python train_gpt2.py`\\n\");\n        return 1;\n    }\n    int B = state_header[2]; // batch size, e.g. 4\n    int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)\n    printf(\"[State]\\n\");\n    printf(\"batch_size: %d\\n\", B);\n    printf(\"seq_len: %d\\n\", T);\n\n    ParameterTensors expected_grads;\n    float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_sizes);\n\n    // inputs and expected outputs, only used for error checking\n    int* x = (int*) malloc(B * T * sizeof(int));\n    int* y = (int*) malloc(B * T * sizeof(int));\n    float* expected_logits = (float*) malloc(B * T * V * sizeof(float));\n    float* expected_loss = (float*) malloc(1 * sizeof(float));\n\n    // read reference information from Python\n    freadCheck(x, sizeof(int), B*T, state_file);\n    freadCheck(y, sizeof(int), B*T, state_file);\n    freadCheck(expected_logits, sizeof(float), B*T*V, state_file);\n    freadCheck(expected_loss, sizeof(float), 1, state_file);\n    freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);\n    fcloseCheck(state_file);\n\n    // overall OK signal for the test\n    int allok = 1;\n\n    // let's do 10 training iterations, following the pytorch code\n    float expected_losses[10] = {\n        5.270007133483887f,\n        4.059706687927246f,\n        3.3751230239868164f,\n        2.8007826805114746f,\n        2.315382242202759f,\n        1.8490285873413086f,\n        1.3946564197540283f,\n        0.9991465210914612f,\n        0.6240804195404053f,\n        0.37651097774505615f\n    };\n    for (int step = 0; step < 10; step++) {\n\n        struct timespec start, end;\n        clock_gettime(CLOCK_MONOTONIC, &start);\n\n        gpt2_forward(&model, x, y, B, T);\n        gpt2_zero_grad(&model);\n        gpt2_backward(&model);\n\n        
clock_gettime(CLOCK_MONOTONIC, &end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;\n\n        if (step == 0) {\n            // error checking at step 0 for reference activations/gradients\n            // at this point, target should be equal to expected_logits, let's compare\n            int logits_ok = 1;\n            float* calculated_logits = model.acts.logits;\n            float max_diff = 0.0f;\n            for (int bt = 0; bt < B*T; bt++) {\n                for (int v = 0; v < V; v++) { // note we only loop to V (ignoring padding)\n                    int i = bt * Vp + v; // linearized index, using Vp\n                    if (i < 10) {\n                        printf(\"%f, %f\\n\", expected_logits[i], calculated_logits[i]);\n                    }\n                    float diff = fabsf(expected_logits[bt*V + v] - calculated_logits[i]);\n                    max_diff = fmaxf(max_diff, diff);\n                    if (diff >= 1e-2f) {\n                        printf(\"MISMATCH AT INDEX %d,%d: \", bt, v);\n                        printf(\"%f %f\\n\", expected_logits[bt*V + v], calculated_logits[i]);\n                        logits_ok = 0;\n                        bt = B*T; // to break out of both loops\n                        break;\n                    }\n                }\n            }\n            if(!logits_ok) { printf(\"NOT \"); }\n            printf(\"OK (LOGITS), max_diff = %e\\n\", max_diff);\n            allok = allok && logits_ok;\n\n            // compare the achieved loss\n            if (fabsf(model.mean_loss - *expected_loss) >= 1e-2) {\n                printf(\"LOSS MISMATCH: %f %f\\n\", model.mean_loss, *expected_loss);\n                allok = 0;\n            } else {\n                printf(\"LOSS OK: %f %f\\n\", model.mean_loss, *expected_loss);\n            }\n\n            // finally check all the gradients\n            int gradoks[16];\n            ParameterTensors grads = 
model.grads;\n            gradoks[0] = check_tensor(grads.wte, expected_grads.wte, V*C, \"dwte\");\n            gradoks[1] = check_tensor(grads.wpe, expected_grads.wpe, maxT*C, \"dwpe\");\n            gradoks[2] = check_tensor(grads.ln1w, expected_grads.ln1w, L*C, \"dln1w\");\n            gradoks[3] = check_tensor(grads.ln1b, expected_grads.ln1b, L*C, \"dln1b\");\n            gradoks[4] = check_tensor(grads.qkvw, expected_grads.qkvw, L*3*C*C, \"dqkvw\");\n            gradoks[5] = check_tensor(grads.qkvb, expected_grads.qkvb, L*3*C, \"dqkvb\");\n            gradoks[6] = check_tensor(grads.attprojw, expected_grads.attprojw, L*C*C, \"dattprojw\");\n            gradoks[7] = check_tensor(grads.attprojb, expected_grads.attprojb, L*C, \"dattprojb\");\n            gradoks[8] = check_tensor(grads.ln2w, expected_grads.ln2w, L*C, \"dln2w\");\n            gradoks[9] = check_tensor(grads.ln2b, expected_grads.ln2b, L*C, \"dln2b\");\n            gradoks[10] = check_tensor(grads.fcw, expected_grads.fcw, L*4*C*C, \"dfcw\");\n            gradoks[11] = check_tensor(grads.fcb, expected_grads.fcb, L*4*C, \"dfcb\");\n            gradoks[12] = check_tensor(grads.fcprojw, expected_grads.fcprojw, L*C*4*C, \"dfcprojw\");\n            gradoks[13] = check_tensor(grads.fcprojb, expected_grads.fcprojb, L*C, \"dfcprojb\");\n            gradoks[14] = check_tensor(grads.lnfw, expected_grads.lnfw, C, \"dlnfw\");\n            gradoks[15] = check_tensor(grads.lnfb, expected_grads.lnfb, C, \"dlnfb\");\n            for (int i = 0; i < 16; i++) {\n                allok = allok && gradoks[i];\n            }\n        }\n\n        gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1);\n\n        // compare the losses\n        float expected_loss = expected_losses[step];\n        float actual_loss = model.mean_loss;\n        int step_loss_ok = fabsf(expected_loss - actual_loss) < 1e-2;\n        allok = allok && step_loss_ok;\n\n        // print the timing information at the end\n        
printf(\"step %d: loss %f (took %f ms) OK = %d\\n\", step, model.mean_loss, time_elapsed_s * 1000, step_loss_ok);\n    }\n\n    // final judgement\n    printf(\"overall okay: %d\\n\", allok);\n\n    // free everything\n    free(x);\n    free(y);\n    free(expected_logits);\n    free(expected_loss);\n    free(expected_grads_memory);\n    gpt2_free(&model);\n    return 0;\n}\n"
  },
  {
    "path": "test_gpt2.cu",
    "content": "#define TESTING\n#include \"train_gpt2.cu\"\n\n// poor man's tensor checker\nint check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) {\n    // a is the calculated tensor, b is the reference tensor\n    int print_upto = 10;\n    int ok = 1;\n    float max_diff = 0.0f;\n    float max_rel_error = 0.0f;\n    float max_to_threshold = 0.f;\n    float max_a = 0.0f;\n    float max_b = 0.0f;\n    float epsilon = 0.079;      // BF16 epsilon value\n    printf(\"---\\n\");\n    printf(\"checking tensor: %s\\n\", label);\n    for (int i = 0; i < n; i++) {\n        float t_eff = threshold + fabs(b[i]) * epsilon;\n        float diff = fabsf(a[i] - b[i]);\n        max_to_threshold = max(max_to_threshold, diff / t_eff);\n        if (diff > max_diff) {\n            max_diff = diff;\n            float denom = fabsf(b[i]);\n            max_rel_error = (denom == 0.0f) ? 0.0f : diff / denom;\n            max_a = a[i];\n            max_b = b[i];\n        }\n        if (diff > t_eff) {\n            ok = 0;\n        }\n        // print the first few elements so we can visually assess the \"proof\" of the comparison\n        if (i < print_upto) {\n            printf(diff <= t_eff ? 
\"OK \" :  \"NOT OK \");\n            printf(\"%f %f\\n\", a[i], b[i]);\n        }\n    }\n    // print the final result\n    if (ok) {\n        printf(\"TENSOR OK, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\\n\",\n                max_diff, max_rel_error, max_a, max_b, max_to_threshold*100);\n    } else {\n        printf(\"TENSOR NOT OK, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\\n\",\n                max_diff, max_rel_error, max_a, max_b, max_to_threshold*100);\n    }\n    return ok;\n}\n\n// the same tensors as in the train file, but in float, which are used as reference\ntypedef struct {\n    float*  wte; // (Vp, C)\n    float*  wpe; // (maxT, C)\n    float*  ln1w; // (L, C)\n    float*  ln1b; // (L, C)\n    float*  qkvw; // (L, 3*C, C)\n    float*  qkvb; // (L, 3*C)\n    float*  attprojw; // (L, C, C)\n    float*  attprojb; // (L, C)\n    float*  ln2w; // (L, C)\n    float*  ln2b; // (L, C)\n    float*  fcw; // (L, 4*C, C)\n    float*  fcb; // (L, 4*C)\n    float*  fcprojw; // (L, C, 4*C)\n    float*  fcprojb; // (L, C)\n    float*  lnfw; // (C)\n    float*  lnfb; // (C)\n} FloatParameterTensors;\nstatic_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), \"Inconsistent sizes!\");\n\n// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU\nfloat* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) {\n    // calculate the total number of parameters\n    size_t num_parameters = 0;\n    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        num_parameters += param_sizes[i];\n    }\n    // everything is float so number of bytes to allocate is a simple multiplication\n    float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float));\n    float** ptrs[] = {\n        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, 
&params->qkvb,\n        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,\n        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb\n    };\n    float* params_memory_iterator = params_memory;\n    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        *(ptrs[i]) = params_memory_iterator;\n        params_memory_iterator += param_sizes[i];\n    }\n    return params_memory;\n}\n\nint main(int argc, char *argv[]) {\n    char nccl_init_method[256] = \"mpi\";  // \"tcp\" or \"fs\" or \"mpi\"\n    int num_processes = -1;  // doesn't matter when using MPI\n    int process_rank = -1;  // doesn't matter when using MPI\n    int gpus_per_node = -1;  // doesn't matter when using MPI\n    char server_ip[256] = \"\";  // doesn't matter when using MPI\n    char fs_path[256] = \"\";  // doesn't matter when using MPI\n    multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);\n    common_start(false, true);\n\n    // set the right paths\n    #if defined(ENABLE_BF16)\n    const char* load_filename = \"gpt2_124M_bf16.bin\";\n    #else\n    const char* load_filename = \"gpt2_124M.bin\";\n    #endif\n\n    // build the GPT-2 model from a checkpoint\n    GPT2 model;\n    gpt2_init_common(&model);\n\n    gpt2_build_from_checkpoint(&model, load_filename);\n    size_t V = model.config.vocab_size;\n    size_t Vp = model.config.padded_vocab_size;\n    size_t maxT = model.config.max_seq_len;\n\n    for (int i = 1; i < argc; i+=2) {\n        if (i + 1 >= argc) { exit(EXIT_FAILURE);  } // must have arg after flag\n        if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { exit(EXIT_FAILURE); } // must be -x[y] (one dash, one or two letters)\n        if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash\n        if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'r') { model.recompute = 
atoi(argv[i+1]); }\n        else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); }\n    }\n\n    // load additional information that we will use for debugging and error checking\n    FILE *state_file = fopenCheck(\"gpt2_124M_debug_state.bin\", \"rb\");\n    int state_header[256];\n    freadCheck(state_header, sizeof(int), 256, state_file);\n    if (state_header[0] != 20240327) { fprintf(stderr, \"Bad magic state file\\n\"); exit(EXIT_FAILURE); }\n    if (state_header[1] != 2) {\n        fprintf(stderr, \"Bad version in state file\\n\");\n        fprintf(stderr, \"---> HINT: try to re-run `python train_gpt2.py`\\n\");\n        exit(EXIT_FAILURE);\n    }\n    int B = state_header[2]; // batch size, e.g. 4\n    int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)\n    assert(0 <= T && T <= maxT);\n    printf(\"[State]\\n\");\n    printf(\"batch_size: %d\\n\", B);\n    printf(\"seq_len: %d\\n\", T);\n\n    set_zero_configs(&multi_gpu_config, 0, model.num_parameters);\n\n    // read reference information from the file saved from Python/PyTorch side\n    // 1) input x and y\n    int* x = (int*)mallocCheck(B * T * sizeof(int));\n    int* y = (int*)mallocCheck(B * T * sizeof(int));\n    freadCheck(x, sizeof(int), B*T, state_file);\n    freadCheck(y, sizeof(int), B*T, state_file);\n    // 2) results of forward pass (logits and loss)\n    float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));\n    float* expected_loss = (float*) mallocCheck(1 * sizeof(float));\n    freadCheck(expected_logits, sizeof(float), B*T*V, state_file);\n    freadCheck(expected_loss, sizeof(float), 1, state_file);\n    // 3) results of backward pass (parameter gradients)\n    FloatParameterTensors expected_grads; // will be read from file. 
right now: all in fp32\n    float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements);\n    freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);\n    fcloseCheck(state_file);\n\n    // this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads\n    void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes);\n    float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float));\n\n    // overall OK signal for the test\n    int allok = 1;\n\n    gpt2_allocate_state(&model, B, T);\n\n    // First, do target-free forward pass to validate logits\n    gpt2_forward(&model, x, B, T);\n    // at this point, target should be equal to expected_logits, let's compare\n    // copy logits to CPU so we can compare them\n    floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX));\n    float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float));\n    cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost));\n    for (int i = 0; i < B * T * Vp; i++) {\n        logits_cpu[i] = (float)logits_cpu_raw[i];\n    }\n\n    float logit_accuracy_threshold = 1e-3f;\n    float loss_diff_threshold = 1e-5f;\n    // FP16 and lower require very high tolerances unfortunately. TODO look into more\n    #if defined(ENABLE_BF16) || defined(ENABLE_F16)\n    logit_accuracy_threshold = 25.0f; // 15.0f was too low even without cuDNN?! 
:(\n    loss_diff_threshold = 0.05f;\n    #endif\n\n    // compare the output logits from the forward pass\n    // also careful that we don't access and compare the padded columns of logits\n    int logits_ok = 1;\n    float max_diff = 0.0f;\n    for (int bt = 0; bt < B*T; bt++) {\n        for (int v = 0; v < V; v++) {\n            int i = bt * Vp + v; // linearized index\n            if (i < 10) {\n                printf(\"%f, %f\\n\", expected_logits[i], logits_cpu[i]);\n            }\n            float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]);\n            max_diff = fmaxf(max_diff, diff);\n            if (diff >= logit_accuracy_threshold) {\n                printf(\"MISMATCH AT INDEX %d,%d: \", bt, v);\n                printf(\"%f %f\\n\", expected_logits[bt*V + v], logits_cpu[i]);\n                logits_ok = 0;\n                bt = B*T; // to break out of both loops\n                break;\n            }\n        }\n    }\n    allok = allok && logits_ok;\n    if(!logits_ok) { printf(\"NOT \"); }\n    printf(\"OK (LOGITS)\\n\");\n    printf(\"logit max diff: %f\\n\", max_diff);\n\n    // let's do 10 training iterations, following the pytorch code\n    float losses[10];\n    for (int step = 0; step < 10; step++) {\n        struct timespec start, end;\n        clock_gettime(CLOCK_MONOTONIC, &start);\n        gpt2_forward(&model, x, B, T);\n        gpt2_backward_and_reduce(&model, x, y, 1, 0);\n        clock_gettime(CLOCK_MONOTONIC, &end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;\n\n        if (step == 0) {\n            // error checking at step 0 for reference activations\n\n            // move the (mixed precision) grads from GPU to CPU\n            cudaCheck(cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost));\n\n            // convert all gradients to float on the CPU\n            char* src_iterator = (char*)grads_memory_cpu; // can be 
lower precision, so we use char*\n            float* dst_iterator = (float*)grads_memory_cpu_float; // float*\n            float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python\n            float* tensors1[NUM_PARAMETER_TENSORS];\n            float* tensors2[NUM_PARAMETER_TENSORS];\n            for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n                if (model.param_sizeof[i] == sizeof(float)) {\n                    // float tensor => copy over directly\n                    memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float));\n                } else {\n                    // low-precision tensor => convert to float\n                    assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm\n                    for (size_t j = 0; j < model.param_elements[i]; j++) {\n                        dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float\n                    }\n                }\n                // for convenience record the position of comparison for reality vs. expectation\n                tensors1[i] = dst_iterator; // reality\n                tensors2[i] = exp_iterator; // expectation\n                // advance the iterators\n                src_iterator += model.param_elements[i] * model.param_sizeof[i];\n                dst_iterator += model.param_elements[i];\n                exp_iterator += model.param_elements[i];\n            }\n\n            // compare the gradients on the parameters all at once, in fp32\n            // I set the tolerances manually by inspecting the gradient differences for\n            // a few elements of each tensor. bf16 looks ok but not amazing here.\n            // It's possible we have bugs lurking, or maybe it is bf16. 
Not 100% sure.\n            // Also, if code changes and some of these get tripped, it could be ok if it's not by too much,\n            // because our use of stochastic rounding is adding some non-determinism \"pepper noise\".\n            // In that case it's ok to extend the tolerance by a bit, after a manual review.\n            // Also, different GPUs may use different matrix multiplication algorithms, so the\n            // actual errors can be hardware specific.\n\n            float grad_thresholds[NUM_PARAMETER_TENSORS] = {\n                    5e-1f, 4e-3f, 1e-1f, 4e-2f,\n                    5e-2f, 3.5e-2f, 2e-2f, 3e-2f,\n                    5e-2f, 3e-2f, 3e-2f, 3e-2f,\n                    2e-2f, 1e-2f,1e-1f,2e-2f};\n\n            #if defined(ENABLE_FP32)\n            for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n                grad_thresholds[i] = 1e-6f;  // we can be much more precise in FP32\n            }\n            #endif\n            const char* names[NUM_PARAMETER_TENSORS] = {\n                    \"wte\", \"wpe\", \"ln1w\", \"ln1b\", \"qkvw\", \"qkvb\", \"attrpojw\",\n                    \"attprojb\", \"ln2w\", \"ln2b\", \"fcw\", \"fcb\", \"fcprojw\", \"fcprojb\",\n                    \"lnfw\", \"lnfb\"\n            };\n            size_t* count = model.param_elements;\n            for(int i = 0; i < NUM_PARAMETER_TENSORS; ++i) {\n                allok = allok & check_tensor(tensors1[i], tensors2[i], count[i], names[i], grad_thresholds[i]);\n            }\n        }\n\n        float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);\n        float grad_scale = (grad_norm > 1.0f) ? 
1.0f / grad_norm : 1.0f;\n        gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config);\n\n        // print the timing information at the end\n        printf(\"step %d: loss %f (took %f ms)\\n\", step+1, model.mean_loss, time_elapsed_s * 1000);\n        // the expected losses from PyTorch were copied over after the print formatting rounded\n        // them to 6 decimal places, so we do the same here\n        float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000;\n        losses[step] = rounded_loss;\n    }\n\n    // expected losses are as follows, from Python\n    float expected_losses[10] = {\n        5.270009f,\n        4.060681f,\n        3.320085f,\n        2.717550f,\n        2.181066f,\n        1.653923f,\n        1.168050f,\n        0.736873f,\n        0.401021f,\n        0.187493f\n    };\n\n    // compare\n    for (int i = 0; i < 10; i++) {\n        if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) {\n            printf(\"LOSS MISMATCH AT STEP %d: %f %f\\n\", i+1, losses[i], expected_losses[i]);\n            allok = 0;\n        } else {\n            printf(\"loss ok at step %d: %f %f\\n\", i+1, losses[i], expected_losses[i]);\n        }\n    }\n\n    // Finally, let's check determinism\n    gpt2_write_to_checkpoint(&model, \"test_gpt2cu_model.ckpt\");\n\n    DataLoader loader;\n    dataloader_init(&loader, \"dev/data/tinyshakespeare/tiny_shakespeare_val.bin\", B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1);\n    save_state(\"test_gpt2cu_state.ckpt\", 10, &model, &loader);\n    int tokens[10];\n    for (int step = 0; step < 10; step++) {\n        dataloader_next_batch(&loader);\n        gpt2_forward(&model, loader.inputs, B, T);\n        gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);\n        gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);\n        losses[step] = model.mean_loss;\n        tokens[step] = 
loader.inputs[0];\n    }\n\n    // reload\n    gpt2_free(&model);\n    gpt2_build_from_checkpoint(&model, \"test_gpt2cu_model.ckpt\");\n    int ld_step;\n    gpt2_allocate_state(&model, B, T);\n    load_state(&ld_step, &model, &loader, \"test_gpt2cu_state.ckpt\");\n    for (int step = 0; step < 10; step++) {\n        dataloader_next_batch(&loader);\n        gpt2_forward(&model, loader.inputs, B, T);\n        gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);\n        gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);\n\n        if(loader.inputs[0] != tokens[step]) {\n            printf(\"Nondeterminism! Token mismatch at step %d: %d vs %d\\n\", step, tokens[step], loader.inputs[0]);\n            allok = false;\n            break;\n        }\n\n        if(losses[step] != model.mean_loss) {\n            printf(\"Nondeterminism! Loss mismatch at step %d: %.15f vs %.15f\\n\", step, losses[step], model.mean_loss);\n            allok = false;\n            break;\n        } else {\n            printf(\"loss ok at step %d: %f %f\\n\", step, losses[step], model.mean_loss);\n        }\n    }\n\n    // final approval\n    printf(\"overall okay: %d\\n\", allok);\n\n    // delete intermediate test files\n    remove(\"test_gpt2cu_model.ckpt\");\n    remove(\"test_gpt2cu_state.ckpt\");\n\n    // free everything\n    dataloader_free(&loader);\n    gpt2_free(&model);\n    common_free(model);\n    free(x);\n    free(y);\n    free(logits_cpu_raw);\n    free(logits_cpu);\n    free(expected_logits);\n    free(expected_loss);\n    free(expected_grads_memory);\n    free(grads_memory_cpu);\n    free(grads_memory_cpu_float);\n    return allok ? EXIT_SUCCESS : EXIT_FAILURE;\n}\n"
  },
  {
    "path": "test_gpt2_fp32.cu",
    "content": "#define TESTING\n#include \"train_gpt2_fp32.cu\"\n\n// poor man's tensor checker\nint check_tensor(float *a, float *b, int n, const char* label) {\n    int print_upto = 5;\n    int ok = 1;\n    printf(\"%s\\n\", label);\n    for (int i = 0; i < n; i++) {\n        if (fabsf(a[i] - b[i]) <= 1e-2) {\n            if (i < print_upto) { printf(\"OK \"); }\n        } else {\n            if (i < print_upto) { printf(\"NOT OK \"); }\n            ok = 0;\n        }\n        if (i < print_upto) { printf(\"%f %f\\n\", a[i], b[i]); }\n    }\n    // print the final result\n    if (ok) {\n        printf(\"TENSOR OK\\n\");\n    } else {\n        printf(\"TENSOR NOT OK\\n\");\n    }\n    return ok;\n}\n\nint main(int argc, char *argv[]) {\n\n    // set up the device\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n    printf(\"[System]\\n\");\n    printf(\"Device %d: %s\\n\", deviceIdx, deviceProp.name);\n\n    // setup cuBLAS and cuBLASLt\n    cublasCheck(cublasCreate(&cublas_handle));\n    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')\n    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;\n    enable_tf32 = 0; // NOTE: disable TF32 for testing!!!\n    printf(\"enable_tf32: %d\\n\", enable_tf32);\n    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;\n    cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;\n    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));\n\n    // build the GPT-2 model from a checkpoint\n    GPT2 model;\n    gpt2_build_from_checkpoint(&model, \"gpt2_124M.bin\");\n\n    // int C = model.config.channels;\n    int V = model.config.vocab_size;\n    int Vp = model.config.padded_vocab_size;\n    int maxT = model.config.max_seq_len;\n    // int L = model.config.num_layers;\n\n    // load additional information that we will use for debugging and error checking\n    FILE *state_file = fopenCheck(\"gpt2_124M_debug_state.bin\", \"rb\");\n    int state_header[256];\n    freadCheck(state_header, sizeof(int), 256, state_file);\n    if (state_header[0] != 20240327) { printf(\"Bad magic state file\\n\"); exit(EXIT_FAILURE); }\n    if (state_header[1] != 2) {\n        fprintf(stderr, \"Bad version in state file\\n\");\n        fprintf(stderr, \"---> HINT: try to re-run `python train_gpt2.py`\\n\");\n        exit(EXIT_FAILURE);\n    }\n    int B = state_header[2]; // batch size, e.g. 4\n    int T = state_header[3]; // time / sequence length (e.g. 
64, up to maxT)\n    assert(0 <= T && T <= maxT);\n    printf(\"[State]\\n\");\n    printf(\"batch_size: %d\\n\", B);\n    printf(\"seq_len: %d\\n\", T);\n\n    ParameterTensors expected_grads; // will be read from file (from PyTorch)\n    ParameterTensors calculated_grads; // will be calculated by us\n    float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_sizes, 0);\n    float* calculated_grads_memory = malloc_and_point_parameters(&calculated_grads, model.param_sizes, 0);\n\n    // inputs and expected outputs, only used for error checking\n    int* x = (int*)mallocCheck(B * T * sizeof(int));\n    int* y = (int*)mallocCheck(B * T * sizeof(int));\n    float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));\n    float* expected_loss = (float*) mallocCheck(1 * sizeof(float));\n\n    // read reference information from Python\n    freadCheck(x, sizeof(int), B*T, state_file);\n    freadCheck(y, sizeof(int), B*T, state_file);\n    freadCheck(expected_logits, sizeof(float), B*T*V, state_file);\n    freadCheck(expected_loss, sizeof(float), 1, state_file);\n    freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);\n    fcloseCheck(state_file);\n\n    // overall OK signal for the test\n    int allok = 1;\n\n    // First, do target-free forward pass to validate logits\n    gpt2_forward(&model, x, NULL, B, T);\n    // at this point, target should be equal to expected_logits, let's compare\n    // copy logits to CPU so we can compare them\n    float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float));\n    cudaCheck(cudaMemcpy(logits_cpu, model.acts.output, B * T * Vp * sizeof(float), cudaMemcpyDeviceToHost));\n\n    // compare the output logits from the forward pass\n    // also careful that we don't access and compare the padded columns of logits\n    int logits_ok = 1;\n    float max_diff = 0.0f;\n    for (int bt = 0; bt < B*T; bt++) {\n        for (int v = 0; v < V; v++) {\n        
    int i = bt * Vp + v; // linearized index\n            if (i < 10) {\n                printf(\"%f, %f\\n\", expected_logits[i], logits_cpu[i]);\n            }\n            float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]);\n            max_diff = fmaxf(max_diff, diff);\n            if (diff >= 1e-2f) {\n                printf(\"MISMATCH AT INDEX %d,%d: \", bt, v);\n                printf(\"%f %f\\n\", expected_logits[bt*V + v], logits_cpu[i]);\n                logits_ok = 0;\n                bt = B*T; // to break out of both loops\n                break;\n            }\n        }\n    }\n    allok = allok && logits_ok;\n    if(!logits_ok) { printf(\"NOT \"); }\n    printf(\"OK (LOGITS)\\n\");\n\n    // let's do 10 training iterations, following the pytorch code\n    float losses[10];\n    for (int step = 0; step < 10; step++) {\n        struct timespec start, end;\n        clock_gettime(CLOCK_MONOTONIC, &start);\n        gpt2_forward(&model, x, y, B, T);\n        gpt2_zero_grad(&model);\n        gpt2_backward(&model);\n        clock_gettime(CLOCK_MONOTONIC, &end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;\n\n        if (step == 0) {\n            // error checking at step 0 for reference activations\n            free(logits_cpu);\n\n            // compare the achieved loss\n            if (fabsf(model.mean_loss - *expected_loss) >= 1e-2) {\n                printf(\"LOSS MISMATCH: %f %f\\n\", model.mean_loss, *expected_loss);\n                allok = 0;\n            } else {\n                printf(\"LOSS OK: %f %f\\n\", model.mean_loss, *expected_loss);\n            }\n\n            // and now compare the gradients on the parameters\n            // cudaMemcpy(calculated_grads.lnfw, model.grads.lnfw, C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.lnfb, model.grads.lnfb, C * sizeof(float), cudaMemcpyDeviceToHost);\n            // 
cudaMemcpy(calculated_grads.fcprojw, model.grads.fcprojw, L * C * 4*C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.fcprojb, model.grads.fcprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.fcw, model.grads.fcw, L * 4*C * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.fcb, model.grads.fcb, L * 4*C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.ln2w, model.grads.ln2w, L * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.ln2b, model.grads.ln2b, L * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.attprojw, model.grads.attprojw, L * C * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.attprojb, model.grads.attprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.qkvw, model.grads.qkvw, L * 3*C * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.qkvb, model.grads.qkvb, L * 3*C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.ln1w, model.grads.ln1w, L * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.ln1b, model.grads.ln1b, L * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.wte, model.grads.wte, V * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // cudaMemcpy(calculated_grads.wpe, model.grads.wpe, maxT * C * sizeof(float), cudaMemcpyDeviceToHost);\n            // check_tensor(calculated_grads.lnfb, expected_grads.lnfb, C, \"lnfb\");\n            // check_tensor(calculated_grads.lnfw, expected_grads.lnfw, C, \"lnfw\");\n            // check_tensor(calculated_grads.fcprojw, expected_grads.fcprojw, L * C * 4*C, \"fcprojw\");\n            // check_tensor(calculated_grads.fcprojb, 
expected_grads.fcprojb, L * C, \"fcprojb\");\n            // check_tensor(calculated_grads.fcw, expected_grads.fcw, L * 4*C * C, \"fcw\");\n            // check_tensor(calculated_grads.fcb, expected_grads.fcb, L * 4*C, \"fcb\");\n            // check_tensor(calculated_grads.ln2w, expected_grads.ln2w, L * C, \"ln2w\");\n            // check_tensor(calculated_grads.ln2b, expected_grads.ln2b, L * C, \"ln2b\");\n            // check_tensor(calculated_grads.attprojw, expected_grads.attprojw, L * C * C, \"attprojw\");\n            // check_tensor(calculated_grads.attprojb, expected_grads.attprojb, L * C, \"attprojb\");\n            // check_tensor(calculated_grads.qkvw, expected_grads.qkvw, L * 3*C * C, \"qkvw\");\n            // check_tensor(calculated_grads.qkvb, expected_grads.qkvb, L * 3*C, \"qkvb\");\n            // check_tensor(calculated_grads.ln1w, expected_grads.ln1w, L * C, \"ln1w\");\n            // check_tensor(calculated_grads.ln1b, expected_grads.ln1b, L * C, \"ln1b\");\n            // check_tensor(calculated_grads.wte, expected_grads.wte, V * C, \"wte\");\n            // check_tensor(calculated_grads.wpe, expected_grads.wpe, maxT * C, \"wpe\");\n\n            // compare the gradients on the parameters all at once\n            cudaMemcpy(calculated_grads_memory, model.grads_memory, model.num_parameters * sizeof(float), cudaMemcpyDeviceToHost);\n            check_tensor(calculated_grads_memory, expected_grads_memory, model.num_parameters, \"grads\");\n        }\n\n        gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1);\n\n        // print the timing information at the end\n        printf(\"step %d: loss %f (took %f ms)\\n\", step, model.mean_loss, time_elapsed_s * 1000);\n        losses[step] = model.mean_loss;\n    }\n\n    // expected losses are as follows, from Python\n    float expected_losses[10] = {\n        5.270007133483887f,\n        4.059706687927246f,\n        3.3751230239868164f,\n        2.8007826805114746f,\n        
2.315382242202759f,\n        1.8490285873413086f,\n        1.3946564197540283f,\n        0.9991465210914612f,\n        0.6240804195404053f,\n        0.37651097774505615f\n    };\n\n    // compare\n    for (int i = 0; i < 10; i++) {\n        if (fabsf(losses[i] - expected_losses[i]) >= 1e-2) {\n            printf(\"LOSS MISMATCH AT STEP %d: %f %f\\n\", i, losses[i], expected_losses[i]);\n            allok = 0;\n        } else {\n            printf(\"loss ok at step %d: %f %f\\n\", i, losses[i], expected_losses[i]);\n        }\n    }\n\n    // final approval\n    printf(\"overall okay: %d\\n\", allok);\n\n    // free everything\n    free(x);\n    free(y);\n    free(expected_logits);\n    free(expected_loss);\n    free(expected_grads_memory);\n    free(calculated_grads_memory);\n    gpt2_free(&model);\n    cublasCheck(cublasDestroy(cublas_handle));\n\n    return 0;\n}"
  },
  {
    "path": "train_gpt2.c",
    "content": "/*\nThis file trains the GPT-2 model.\nThis version is the clean, minimal, reference. As such:\n- it runs on CPU.\n- it does not make the code too complex; it is readable.\n- it does not use any processor-specific instructions, intrinsics and such.\n- it _does_ use a few OpenMP pragmas because this is a large speedup at very low cost\nThere will be other versions of this code that specialize it and make it fast.\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <ctype.h>\n#include <stdint.h>\n#include <assert.h>\n#include <math.h>\n#include <time.h>\n#include <string.h>\n#include <unistd.h>\n#ifdef OMP\n#include <omp.h>\n#endif\n// our own utilities\n// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck\n#include \"llmc/utils.h\"\n// defines: tokenizer_init, tokenizer_decode, tokenizer_free\n#include \"llmc/tokenizer.h\"\n// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free\n#include \"llmc/dataloader.h\"\n\n// ----------------------------------------------------------------------------\n// all the individual layers' forward and backward passes\n// B = batch_size, T = sequence_length, C = channels, V = vocab_size\n\nvoid encoder_forward(float* out,\n                   int* inp, float* wte, float* wpe,\n                   int B, int T, int C) {\n    // out is (B,T,C). 
At each position (b,t), a C-dimensional vector summarizing token & position\n    // inp is (B,T) of integers, holding the token ids at each (b,t) position\n    // wte is (V,C) of token embeddings, short for \"weight token embeddings\"\n    // wpe is (maxT,C) of position embeddings, short for \"weight positional embedding\"\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the output position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            // get the index of the token at inp[b, t]\n            int ix = inp[b * T + t];\n            // seek to the position in wte corresponding to the token\n            float* wte_ix = wte + ix * C;\n            // seek to the position in wpe corresponding to the position\n            float* wpe_t = wpe + t * C;\n            // add the two vectors and store the result in out[b,t,:]\n            for (int i = 0; i < C; i++) {\n                out_bt[i] = wte_ix[i] + wpe_t[i];\n            }\n        }\n    }\n}\n\nvoid encoder_backward(float* dwte, float* dwpe,\n                      float* dout, int* inp,\n                      int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dout_bt = dout + b * T * C + t * C;\n            int ix = inp[b * T + t];\n            float* dwte_ix = dwte + ix * C;\n            float* dwpe_t = dwpe + t * C;\n            for (int i = 0; i < C; i++) {\n                float d = dout_bt[i];\n                dwte_ix[i] += d;\n                dwpe_t[i] += d;\n            }\n        }\n    }\n}\n\nvoid layernorm_forward(float* out, float* mean, float* rstd,\n                       float* inp, float* weight, float* bias,\n                       int B, int T, int C) {\n    // reference: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html\n    // both inp and out are (B,T,C) of the activations\n    // mean and rstd are (B,T) buffers, to be used later in 
backward pass\n    // at each position (b,t) of the input, the C-dimensional vector\n    // of activations gets normalized, then scaled and shifted\n    float eps = 1e-5f;\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // seek to the input position inp[b,t,:]\n            float* x = inp + b * T * C + t * C;\n            // calculate the mean\n            float m = 0.0f;\n            for (int i = 0; i < C; i++) {\n                m += x[i];\n            }\n            m = m/C;\n            // calculate the variance (without any bias correction)\n            float v = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float xshift = x[i] - m;\n                v += xshift * xshift;\n            }\n            v = v/C;\n            // calculate the rstd (reciprocal standard deviation)\n            float s = 1.0f / sqrtf(v + eps);\n            // seek to the output position in out[b,t,:]\n            float* out_bt = out + b * T * C + t * C;\n            for (int i = 0; i < C; i++) {\n                float n = (s * (x[i] - m)); // normalize\n                float o = n * weight[i] + bias[i]; // scale and shift\n                out_bt[i] = o; // write\n            }\n            // cache the mean and rstd for the backward pass later\n            mean[b * T + t] = m;\n            rstd[b * T + t] = s;\n        }\n    }\n}\n\nvoid layernorm_backward(float* dinp, float* dweight, float* dbias,\n                        float* dout, float* inp, float* weight, float* mean, float* rstd,\n                        int B, int T, int C) {\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dout_bt = dout + b * T * C + t * C;\n            float* inp_bt = inp + b * T * C + t * C;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            float mean_bt = mean[b * T + t];\n            float rstd_bt = rstd[b * T + t];\n\n            // first: two reduce operations\n            float 
dnorm_mean = 0.0f;\n            float dnorm_norm_mean = 0.0f;\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                dnorm_mean += dnorm_i;\n                dnorm_norm_mean += dnorm_i * norm_bti;\n            }\n            dnorm_mean = dnorm_mean / C;\n            dnorm_norm_mean = dnorm_norm_mean / C;\n\n            // now iterate again and accumulate all the gradients\n            for (int i = 0; i < C; i++) {\n                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n                float dnorm_i = weight[i] * dout_bt[i];\n                // gradient contribution to bias\n                dbias[i] += dout_bt[i];\n                // gradient contribution to weight\n                dweight[i] += norm_bti * dout_bt[i];\n                // gradient contribution to input\n                float dval = 0.0f;\n                dval += dnorm_i; // term 1\n                dval -= dnorm_mean; // term 2\n                dval -= norm_bti * dnorm_norm_mean; // term 3\n                dval *= rstd_bt; // final scale\n                dinp_bt[i] += dval;\n            }\n        }\n    }\n}\n\nvoid matmul_forward_naive(float* out,\n                         const float* inp, const float* weight, const float* bias,\n                         int B, int T, int C, int OC) {\n    // the most naive implementation of matrix multiplication\n    // this serves as an algorithmic reference, and as a fallback for\n    // unfriendly input shapes inside matmul_forward(), below.\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            int bt = b * T + t;\n            for (int o = 0; o < OC; o++) {\n                float val = (bias != NULL) ? 
bias[o] : 0.0f;\n                for (int i = 0; i < C; i++) {\n                    val += inp[bt * C + i] * weight[o*C + i];\n                }\n                out[bt * OC + o] = val;\n            }\n        }\n    }\n}\n\nvoid matmul_forward(float* out,\n                    const float* inp, const float* weight, const float* bias,\n                    int B, int T, int C, int OC) {\n    // most of the running time is spent here and in matmul_backward\n    // therefore, the implementation below is very mildly optimized\n    // this function is otherwise identical to that of matmul_forward_naive()\n    // OC is short for \"output channels\"\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // out will be (B,T,OC)\n\n    // make sure the tiled loop will be correct or fallback to naive version\n    const int LOOP_UNROLL = 8;\n    if (B*T % LOOP_UNROLL != 0) {\n        matmul_forward_naive(out, inp, weight, bias, B, T, C, OC);\n        return;\n    }\n\n    // collapse the B and T loops into one and turn it into a strided loop.\n    // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times\n    #pragma omp parallel for\n    for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {\n        for (int o = 0; o < OC; o++) {\n            // we'll keep LOOP_UNROLL many results in registers\n            float result[LOOP_UNROLL];\n            // initialize the bias, if it exists\n            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {\n                result[ibt] = (bias != NULL) ? bias[o] : 0.0f;\n            }\n            // inner loops. 
Because we do LOOP_UNROLL steps of inner bt, we can cache\n            // the value of weight[i + o * C] and reuse it.\n            // we compile with -Ofast, so the compiler will turn the inner loop into FMAs\n            for (int i = 0; i < C; i++) {\n                float w = weight[i + o * C];\n                for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {\n                    int bt = obt + ibt;\n                    result[ibt] += inp[bt * C + i] * w;\n                }\n            }\n            // write back results to main memory\n            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {\n                int bt = obt + ibt;\n                out[bt * OC + o] = result[ibt];\n            }\n        }\n    }\n}\n\nvoid matmul_backward(float* dinp, float* dweight, float* dbias,\n                     const float* dout, const float* inp, const float* weight,\n                     int B, int T, int C, int OC) {\n    // most of the running time is spent here and in matmul_forward\n    // this backward could be done in a single \"round\" of loops\n    // but that doesn't afford an efficient parallelization strategy\n\n    // backward into inp first, parallelize over B,T\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            const float* dout_bt = dout + b * T * OC + t * OC;\n            float* dinp_bt = dinp + b * T * C + t * C;\n            for (int o = 0; o < OC; o++) {\n                const float* wrow = weight + o*C;\n                float d = dout_bt[o];\n                for (int i = 0; i < C; i++) {\n                    dinp_bt[i] += wrow[i] * d;\n                }\n            }\n        }\n    }\n    // backward into weight/bias, parallelize over output channels OC\n    #pragma omp parallel for\n    for (int o = 0; o < OC; o++) {\n        for (int b = 0; b < B; b++) {\n            for (int t = 0; t < T; t++) {\n                const float* dout_bt = dout + b * T * OC + t * OC;\n        
        const float* inp_bt = inp + b * T * C + t * C;\n                float* dwrow = dweight + o*C;\n                float d = dout_bt[o];\n                if (dbias != NULL) { dbias[o] += d; }\n                for (int i = 0; i < C; i++) {\n                    dwrow[i] += inp_bt[i] * d;\n                }\n            }\n        }\n    }\n}\n\nvoid attention_forward(float* out, float* preatt, float* att,\n                       float* inp,\n                       int B, int T, int C, int NH) {\n    // input is (B, T, 3C) holding the query, key, value (Q, K, V) vectors\n    // preatt, att are (B, NH, T, T). NH = number of heads, T = sequence length\n    // that holds the pre-attention and post-attention scores (used in backward)\n    // output is (B, T, C)\n    // attention is the only layer that mixes information across time\n    // every other operation is applied at every (b,t) position independently\n    // (and of course, no layer mixes information across batch)\n    int C3 = C*3;\n    int hs = C / NH; // head size\n    float scale = 1.0 / sqrtf(hs);\n\n    #pragma omp parallel for collapse(3)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            for (int h = 0; h < NH; h++) {\n                float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n\n                // pass 1: calculate query dot key and maxval\n                float maxval = -10000.0f; // TODO something better\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n\n                    // (query_t) dot (key_t2)\n                    float val = 0.0f;\n                    for (int i = 0; i < hs; i++) {\n                        val += query_t[i] * key_t2[i];\n                    }\n                    val *= scale;\n                    
if (val > maxval) {\n                        maxval = val;\n                    }\n\n                    preatt_bth[t2] = val;\n                }\n\n                // pass 2: calculate the exp and keep track of sum\n                // maxval is being calculated and subtracted only for numerical stability\n                float expsum = 0.0f;\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float expv = expf(preatt_bth[t2] - maxval);\n                    expsum += expv;\n                    att_bth[t2] = expv;\n                }\n                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;\n\n                // pass 3: normalize to get the softmax\n                for (int t2 = 0; t2 < T; t2++) {\n                    if (t2 <= t) {\n                        att_bth[t2] *= expsum_inv;\n                    } else {\n                        // causal attention mask. not strictly necessary to set to zero here\n                        // only doing this explicitly for debugging and checking to PyTorch\n                        att_bth[t2] = 0.0f;\n                    }\n                }\n\n                // pass 4: accumulate weighted values into the output of attention\n                float* out_bth = out + b * T * C + t * C + h * hs;\n                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value\n                    float att_btht2 = att_bth[t2];\n                    for (int i = 0; i < hs; i++) {\n                        out_bth[i] += att_btht2 * value_t2[i];\n                    }\n                }\n            }\n        }\n    }\n}\n\nvoid attention_backward(float* dinp, float* dpreatt, float* datt,\n                        float* dout, float* inp, float* att,\n                        int B, int T, int C, int NH) {\n    // inp/dinp are (B, T, 3C) Q,K,V\n    // 
att/datt/dpreatt are (B, NH, T, T)\n    // dout is (B, T, C)\n    int C3 = C*3;\n    int hs = C / NH; // head size\n    float scale = 1.f / sqrtf(hs);\n\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            for (int h = 0; h < NH; h++) {\n                float* att_bth = att + b*NH*T*T + h*T*T + t*T;\n                float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;\n                float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;\n                float* dquery_t = dinp + b * T * C3 + t * C3 + h * hs;\n                float* query_t = inp + b * T * C3 + t * C3 + h * hs;\n\n                // backward pass 4, through the value accumulation\n                float* dout_bth = dout + b * T * C + t * C + h * hs;\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value\n                    float* dvalue_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C*2;\n                    for (int i = 0; i < hs; i++) {\n                        // in the forward pass this was:\n                        // out_bth[i] += att_bth[t2] * value_t2[i];\n                        // so now we have:\n                        datt_bth[t2] += value_t2[i] * dout_bth[i];\n                        dvalue_t2[i] += att_bth[t2] * dout_bth[i];\n                    }\n                }\n\n                // backward pass 2 & 3, the softmax\n                // note that softmax (like e.g. tanh) doesn't need the input (preatt) to backward\n                for (int t2 = 0; t2 <= t; t2++) {\n                    for (int t3 = 0; t3 <= t; t3++) {\n                        float indicator = t2 == t3 ? 
1.0f : 0.0f;\n                        float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);\n                        dpreatt_bth[t3] += local_derivative * datt_bth[t2];\n                    }\n                }\n\n                // backward pass 1, the query @ key matmul\n                for (int t2 = 0; t2 <= t; t2++) {\n                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n                    float* dkey_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key\n                    for (int i = 0; i < hs; i++) {\n                        // in the forward pass this was:\n                        // preatt_bth[t2] += (query_t[i] * key_t2[i]) * scale;\n                        // so now we have:\n                        dquery_t[i] += key_t2[i] * dpreatt_bth[t2] * scale;\n                        dkey_t2[i] += query_t[i] * dpreatt_bth[t2] * scale;\n                    }\n                }\n            }\n        }\n    }\n}\n\n#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)\nvoid gelu_forward(float* out, float* inp, int N) {\n    // (approximate) GeLU elementwise non-linearity in the MLP block of Transformer\n    for (int i = 0; i < N; i++) {\n        float x = inp[i];\n        float cube = 0.044715f * x * x * x;\n        out[i] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cube)));\n    }\n}\n\n// we want to use -Ofast optimization, but sadly GeLU breaks, so disable this flag just for it (#168)\n#pragma float_control(precise, on, push)\n#if defined(__GNUC__) && !defined(__clang__)\n__attribute__((optimize(\"no-finite-math-only\")))\n#endif\nvoid gelu_backward(float* dinp, float* inp, float* dout, int N) {\n    for (int i = 0; i < N; i++) {\n        float x = inp[i];\n        float cube = 0.044715f * x * x * x;\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        float tanh_out = tanhf(tanh_arg);\n        float coshf_out = coshf(tanh_arg);\n        float sech_out = 1.0f / 
(coshf_out * coshf_out);\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        dinp[i] += local_grad * dout[i];\n    }\n}\n#pragma float_control(pop)\n\nvoid residual_forward(float* out, float* inp1, float* inp2, int N) {\n    for (int i = 0; i < N; i++) {\n        out[i] = inp1[i] + inp2[i];\n    }\n}\n\nvoid residual_backward(float* dinp1, float* dinp2, float* dout, int N) {\n    for (int i = 0; i < N; i++) {\n        dinp1[i] += dout[i];\n        dinp2[i] += dout[i];\n    }\n}\n\nvoid softmax_forward(float* probs, float* logits, int B, int T, int V, int Vp) {\n    // output: probs are (B,T,Vp) of the probabilities (sums to 1.0 in each b,t position)\n    // input: logits is (B,T,Vp) of the unnormalized log probabilities\n    // Vp is the padded vocab size (for efficiency), V is the \"real\" vocab size\n    // example: Vp is 50304 and V is 50257\n    #pragma omp parallel for collapse(2)\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // probs <- softmax(logits)\n            float* logits_bt = logits + b * T * Vp + t * Vp;\n            float* probs_bt = probs + b * T * Vp + t * Vp;\n\n            // maxval is only calculated and subtracted for numerical stability\n            float maxval = -10000.0f; // TODO something better\n            for (int i = 0; i < V; i++) {\n                if (logits_bt[i] > maxval) {\n                    maxval = logits_bt[i];\n                }\n            }\n            float sum = 0.0f;\n            for (int i = 0; i < V; i++) {\n                probs_bt[i] = expf(logits_bt[i] - maxval);\n                sum += probs_bt[i];\n            }\n            // note we only loop to V, leaving the padded dimensions\n            for (int i = 0; i < V; i++) {\n                probs_bt[i] /= sum;\n            }\n            // for extra super safety we may wish to include this too,\n            // forcing the 
probabilities here to be zero, but it shouldn't matter\n            for (int i = V; i < Vp; i++) {\n                probs_bt[i] = 0.0f;\n            }\n        }\n    }\n}\n\nvoid crossentropy_forward(float* losses,\n                          float* probs, int* targets,\n                          int B, int T, int Vp) {\n    // output: losses is (B,T) of the individual losses at each position\n    // input: probs are (B,T,Vp) of the probabilities\n    // input: targets is (B,T) of integers giving the correct index in logits\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            // loss = -log(probs[target])\n            float* probs_bt = probs + b * T * Vp + t * Vp;\n            int ix = targets[b * T + t];\n            losses[b * T + t] = -logf(probs_bt[ix]);\n        }\n    }\n}\n\nvoid crossentropy_softmax_backward(float* dlogits,\n                           float* dlosses, float* probs, int* targets,\n                           int B, int T, int V, int Vp) {\n    // backwards through both softmax and crossentropy\n    for (int b = 0; b < B; b++) {\n        for (int t = 0; t < T; t++) {\n            float* dlogits_bt = dlogits + b * T * Vp + t * Vp;\n            float* probs_bt = probs + b * T * Vp + t * Vp;\n            float dloss = dlosses[b * T + t];\n            int ix = targets[b * T + t];\n            // note we only loop to V, leaving the padded dimensions\n            // of dlogits untouched, so gradient there stays at zero\n            for (int i = 0; i < V; i++) {\n                float p = probs_bt[i];\n                float indicator = i == ix ? 1.0f : 0.0f;\n                dlogits_bt[i] += (p - indicator) * dloss;\n            }\n        }\n    }\n}\n\n// ----------------------------------------------------------------------------\n// GPT-2 model definition\n\ntypedef struct {\n    int max_seq_len; // max sequence length, e.g. 1024\n    int vocab_size; // vocab size, e.g. 
50257\n    int padded_vocab_size; // padded to e.g. %128==0, 50304\n    int num_layers; // number of layers, e.g. 12\n    int num_heads; // number of heads in attention, e.g. 12\n    int channels; // number of channels, e.g. 768\n} GPT2Config;\n\n// the parameters of the model\n#define NUM_PARAMETER_TENSORS 16\ntypedef struct {\n    float* wte; // (V, C)\n    float* wpe; // (maxT, C)\n    float* ln1w; // (L, C)\n    float* ln1b; // (L, C)\n    float* qkvw; // (L, 3*C, C)\n    float* qkvb; // (L, 3*C)\n    float* attprojw; // (L, C, C)\n    float* attprojb; // (L, C)\n    float* ln2w; // (L, C)\n    float* ln2b; // (L, C)\n    float* fcw; // (L, 4*C, C)\n    float* fcb; // (L, 4*C)\n    float* fcprojw; // (L, C, 4*C)\n    float* fcprojb; // (L, C)\n    float* lnfw; // (C)\n    float* lnfb; // (C)\n} ParameterTensors;\n\nvoid fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) {\n    size_t Vp = config.padded_vocab_size;\n    size_t C = config.channels;\n    size_t maxT = config.max_seq_len;\n    size_t L = config.num_layers;\n    param_sizes[0] = Vp * C; // wte\n    param_sizes[1] = maxT * C; // wpe\n    param_sizes[2] = L * C; // ln1w\n    param_sizes[3] = L * C; // ln1b\n    param_sizes[4] = L * (3 * C) * C; // qkvw\n    param_sizes[5] = L * (3 * C); // qkvb\n    param_sizes[6] = L * C * C; // attprojw\n    param_sizes[7] = L * C; // attprojb\n    param_sizes[8] = L * C; // ln2w\n    param_sizes[9] = L * C; // ln2b\n    param_sizes[10] = L * (4 * C) * C; // fcw\n    param_sizes[11] = L * (4 * C); // fcb\n    param_sizes[12] = L * C * (4 * C); // fcprojw\n    param_sizes[13] = L * C; // fcprojb\n    param_sizes[14] = C; // lnfw\n    param_sizes[15] = C; // lnfb\n}\n\n// allocate memory for the parameters and point the individual tensors to the right places\nfloat* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes) {\n    size_t num_parameters = 0;\n    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        num_parameters 
+= param_sizes[i];\n    }\n    // malloc all parameters all at once\n    float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float));\n    // assign all the tensors\n    float** ptrs[] = {\n        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,\n        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,\n        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb\n    };\n    float* params_memory_iterator = params_memory;\n    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        *(ptrs[i]) = params_memory_iterator;\n        params_memory_iterator += param_sizes[i];\n    }\n    return params_memory;\n}\n\n#define NUM_ACTIVATION_TENSORS 23\ntypedef struct {\n    float* encoded; // (B, T, C)\n    float* ln1; // (L, B, T, C)\n    float* ln1_mean; // (L, B, T)\n    float* ln1_rstd; // (L, B, T)\n    float* qkv; // (L, B, T, 3*C)\n    float* atty; // (L, B, T, C)\n    float* preatt; // (L, B, NH, T, T)\n    float* att; // (L, B, NH, T, T)\n    float* attproj; // (L, B, T, C)\n    float* residual2; // (L, B, T, C)\n    float* ln2; // (L, B, T, C)\n    float* ln2_mean; // (L, B, T)\n    float* ln2_rstd; // (L, B, T)\n    float* fch; // (L, B, T, 4*C)\n    float* fch_gelu; // (L, B, T, 4*C)\n    float* fcproj; // (L, B, T, C)\n    float* residual3; // (L, B, T, C)\n    float* lnf; // (B, T, C)\n    float* lnf_mean; // (B, T)\n    float* lnf_rstd; // (B, T)\n    float* logits; // (B, T, V)\n    float* probs; // (B, T, V)\n    float* losses; // (B, T)\n} ActivationTensors;\n\nvoid fill_in_activation_sizes(size_t* act_sizes, GPT2Config config, int B, int T) {\n    size_t C = config.channels;\n    size_t NH = config.num_heads;\n    size_t L = config.num_layers;\n    size_t Vp = config.padded_vocab_size;\n    act_sizes[0] = B * T * C; // encoded\n    act_sizes[1] = L * B * T * C; // ln1\n    act_sizes[2] = L * B * T; // ln1_mean\n    act_sizes[3] = L * B * 
T; // ln1_rstd\n    act_sizes[4] = L * B * T * 3 * C; // qkv\n    act_sizes[5] = L * B * T * C; // atty\n    act_sizes[6] = L * B * NH * T * T; // preatt\n    act_sizes[7] = L * B * NH * T * T; // att\n    act_sizes[8] = L * B * T * C; // attproj\n    act_sizes[9] = L * B * T * C; // residual2\n    act_sizes[10] = L * B * T * C; // ln2\n    act_sizes[11] = L * B * T; // ln2_mean\n    act_sizes[12] = L * B * T; // ln2_rstd\n    act_sizes[13] = L * B * T * 4 * C; // fch\n    act_sizes[14] = L * B * T * 4 * C; // fch_gelu\n    act_sizes[15] = L * B * T * C; // fcproj\n    act_sizes[16] = L * B * T * C; // residual3\n    act_sizes[17] = B * T * C; // lnf\n    act_sizes[18] = B * T; // lnf_mean\n    act_sizes[19] = B * T; // lnf_rstd\n    act_sizes[20] = B * T * Vp; // logits\n    act_sizes[21] = B * T * Vp; // probs\n    act_sizes[22] = B * T; // losses\n}\n\nfloat* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) {\n    size_t num_activations = 0;\n    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n        num_activations += act_sizes[i];\n    }\n    float* acts_memory = (float*)mallocCheck(num_activations * sizeof(float));\n    float** ptrs[] = {\n        &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->qkv, &acts->atty,\n        &acts->preatt, &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,\n        &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,\n        &acts->lnf_mean, &acts->lnf_rstd, &acts->logits, &acts->probs, &acts->losses\n    };\n    float* acts_memory_iterator = acts_memory;\n    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n        *(ptrs[i]) = acts_memory_iterator;\n        acts_memory_iterator += act_sizes[i];\n    }\n    return acts_memory;\n}\n\ntypedef struct {\n    GPT2Config config;\n    // the weights (parameters) of the model, and their sizes\n    ParameterTensors params;\n    size_t 
param_sizes[NUM_PARAMETER_TENSORS];\n    float* params_memory;\n    size_t num_parameters;\n    // gradients of the weights\n    ParameterTensors grads;\n    float* grads_memory;\n    // buffers for the AdamW optimizer\n    float* m_memory;\n    float* v_memory;\n    // the activations of the model, and their sizes\n    ActivationTensors acts;\n    size_t act_sizes[NUM_ACTIVATION_TENSORS];\n    float* acts_memory;\n    size_t num_activations;\n    // gradients of the activations\n    ActivationTensors grads_acts;\n    float* grads_acts_memory;\n    // other run state configuration\n    int batch_size; // the batch size (B) of current forward pass\n    int seq_len; // the sequence length (T) of current forward pass\n    int* inputs; // the input tokens for the current forward pass\n    int* targets; // the target tokens for the current forward pass\n    float mean_loss; // after a forward pass with targets, will be populated with the mean loss\n} GPT2;\n\nvoid gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {\n\n    // read in model from a checkpoint file\n    FILE *model_file = fopenCheck(checkpoint_path, \"rb\");\n    int model_header[256];\n    freadCheck(model_header, sizeof(int), 256, model_file);\n    if (model_header[0] != 20240326) { printf(\"Bad magic model file\\n\"); exit(1); }\n    if (model_header[1] != 3) {\n        printf(\"Bad version in model file\\n\");\n        printf(\"---> HINT: try to re-run `python train_gpt2.py`\\n\");\n        exit(1);\n    }\n\n    // read in hyperparameters\n    size_t maxT, V, Vp, L, NH, C; // size_t to prevent int overflow\n    model->config.max_seq_len = maxT = model_header[2];\n    model->config.vocab_size = V = model_header[3];\n    model->config.num_layers = L = model_header[4];\n    model->config.num_heads = NH = model_header[5];\n    model->config.channels = C = model_header[6];\n    model->config.padded_vocab_size = Vp = model_header[7];\n    printf(\"[GPT-2]\\n\");\n    printf(\"max_seq_len: 
%zu\\n\", maxT);\n    printf(\"vocab_size: %zu\\n\", V);\n    printf(\"padded_vocab_size: %zu\\n\", Vp);\n    printf(\"num_layers: %zu\\n\", L);\n    printf(\"num_heads: %zu\\n\", NH);\n    printf(\"channels: %zu\\n\", C);\n\n    // allocate space for all the parameters and read them in\n    fill_in_parameter_sizes(model->param_sizes,  model->config);\n\n    // count the number of parameters\n    size_t num_parameters = 0;\n    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        num_parameters += model->param_sizes[i];\n    }\n    printf(\"num_parameters: %zu\\n\", num_parameters);\n    model->num_parameters = num_parameters;\n\n    // read in all the parameters from file\n    model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes);\n    freadCheck(model->params_memory, sizeof(float), num_parameters, model_file);\n    fcloseCheck(model_file);\n\n    // other inits\n    model->acts_memory = NULL;\n    model->grads_memory = NULL;\n    model->m_memory = NULL;\n    model->v_memory = NULL;\n    model->grads_acts_memory = NULL;\n    model->inputs = NULL;\n    model->targets = NULL;\n    model->batch_size = 0;\n    model->seq_len = 0;\n    model->mean_loss = -1.0f; // -1.0f will designate no loss\n}\n\nvoid gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {\n    // targets are optional and could be NULL\n\n    // ensure the model was initialized or error out\n    if (model->params_memory == NULL) {\n        printf(\"Error: model was not initialized properly.\\n\");\n        exit(1);\n    }\n\n    // convenience parameters (size_t to help prevent int overflow)\n    size_t V = model->config.vocab_size;\n    size_t Vp = model->config.padded_vocab_size;\n    size_t L = model->config.num_layers;\n    size_t NH = model->config.num_heads;\n    size_t C = model->config.channels;\n\n    // validate inputs, all indices must be in the range [0, V)\n    for(int i = 0; i < B * T; i++) {\n        assert(0 <= inputs[i] && 
inputs[i] < V);\n        if (targets != NULL) {\n            assert(0 <= targets[i] && targets[i] < V);\n        }\n    }\n\n    // allocate space for all the activations if needed (done here, lazily)\n    if(model->acts_memory == NULL) {\n        // record the current B,T as well\n        model->batch_size = B;\n        model->seq_len = T;\n        // and now allocate the space\n        fill_in_activation_sizes(model->act_sizes, model->config, B, T);\n        size_t num_activations = 0;\n        for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n            num_activations += model->act_sizes[i];\n        }\n        printf(\"num_activations: %zu\\n\", num_activations);\n        model->num_activations = num_activations;\n        model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);\n        // also create memory for caching inputs and targets\n        model->inputs = (int*)mallocCheck(B * T * sizeof(int));\n        model->targets = (int*)mallocCheck(B * T * sizeof(int)); // might be unused if we never have targets but it's small\n    } else {\n        // validate B,T is consistent with how we've allocated the memory before\n        // in principle we could get more clever here in the future, for now this is safest\n        if (B != model->batch_size || T != model->seq_len) {\n            printf(\"Model: B=%d T=%d, Desired: B=%d T=%d\\n\", model->batch_size, model->seq_len, (int)B, (int)T);\n            exit(EXIT_FAILURE);\n        }\n    }\n\n    // cache the inputs/targets\n    memcpy(model->inputs, inputs, B * T * sizeof(int));\n    if (targets != NULL) {\n        memcpy(model->targets, targets, B * T * sizeof(int));\n    }\n\n    // forward pass\n    ParameterTensors params = model->params; // for brevity\n    ActivationTensors acts = model->acts;\n    float* residual;\n    encoder_forward(acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]\n    for (int l = 0; l < L; l++) {\n\n        
residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;\n\n        // get the pointers of the weights for this layer\n        float* l_ln1w = params.ln1w + l * C;\n        float* l_ln1b = params.ln1b + l * C;\n        float* l_qkvw = params.qkvw + l * 3*C * C;\n        float* l_qkvb = params.qkvb + l * 3*C;\n        float* l_attprojw = params.attprojw + l * C * C;\n        float* l_attprojb = params.attprojb + l * C;\n        float* l_ln2w = params.ln2w + l * C;\n        float* l_ln2b = params.ln2b + l * C;\n        float* l_fcw = params.fcw + l * 4*C * C;\n        float* l_fcb = params.fcb + l * 4*C;\n        float* l_fcprojw = params.fcprojw + l * C * 4*C;\n        float* l_fcprojb = params.fcprojb + l * C;\n\n        // get the pointers of the activations for this layer\n        float* l_ln1 = acts.ln1 + l * B * T * C;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        float* l_qkv = acts.qkv + l * B * T * 3*C;\n        float* l_atty = acts.atty + l * B * T * C;\n        float* l_preatt = acts.preatt + l * B * NH * T * T;\n        float* l_att = acts.att + l * B * NH * T * T;\n        float* l_attproj = acts.attproj + l * B * T * C;\n        float* l_residual2 = acts.residual2 + l * B * T * C;\n        float* l_ln2 = acts.ln2 + l * B * T * C;\n        float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        float* l_fch = acts.fch + l * B * T * 4*C;\n        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;\n        float* l_fcproj = acts.fcproj + l * B * T * C;\n        float* l_residual3 = acts.residual3 + l * B * T * C;\n\n        // now do the forward pass\n        layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);\n        matmul_forward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);\n        attention_forward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH);\n        
matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);\n        residual_forward(l_residual2, residual, l_attproj, B*T*C);\n        layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);\n        matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);\n        gelu_forward(l_fch_gelu, l_fch, B*T*4*C);\n        matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);\n        residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);\n    }\n    residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3\n    layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);\n    matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp);\n    softmax_forward(acts.probs, acts.logits, B, T, V, Vp);\n\n    // also forward the cross-entropy loss function if we have the targets\n    if (targets != NULL) {\n        crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, Vp);\n        // for convenience also evaluate the mean loss\n        float mean_loss = 0.0f;\n        for (int i=0; i<B*T; i++) { mean_loss += model->acts.losses[i]; }\n        mean_loss /= B*T;\n        model->mean_loss = mean_loss;\n    } else {\n        // if we don't have targets, we don't have a loss\n        model->mean_loss = -1.0f;\n    }\n}\n\nvoid gpt2_zero_grad(GPT2 *model) {\n    if(model->grads_memory != NULL) { memset(model->grads_memory, 0, model->num_parameters * sizeof(float)); }\n    if(model->grads_acts_memory != NULL) { memset(model->grads_acts_memory, 0, model->num_activations * sizeof(float)); }\n}\n\nvoid gpt2_backward(GPT2 *model) {\n\n    // double check we forwarded previously, with targets\n    if (model->mean_loss == -1.0f) {\n        printf(\"Error: must forward with targets before backward\\n\");\n        exit(1);\n    }\n\n    // lazily allocate the memory for gradients of the weights and activations, if needed\n    if 
(model->grads_memory == NULL) {\n        model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes);\n        model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, model->act_sizes);\n        gpt2_zero_grad(model);\n    }\n\n    // convenience shortcuts (and size_t to help prevent int overflow)\n    size_t B = model->batch_size;\n    size_t T = model->seq_len;\n    size_t V = model->config.vocab_size;\n    size_t Vp = model->config.padded_vocab_size;\n    size_t L = model->config.num_layers;\n    size_t NH = model->config.num_heads;\n    size_t C = model->config.channels;\n\n    // backward pass: go in the reverse order of the forward pass, and call backward() functions\n    ParameterTensors params = model->params; // for brevity\n    ParameterTensors grads = model->grads;\n    ActivationTensors acts = model->acts;\n    ActivationTensors grads_acts = model->grads_acts;\n\n    // we kick off the chain rule by filling in dlosses with 1.0f/(B*T)\n    // technically this is a small, inline backward() pass of calculating\n    // total, final loss as the mean over all losses over all (B,T) positions in the batch\n    float dloss_mean = 1.0f / (B*T);\n    for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; }\n\n    crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp);\n    matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp);\n    float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual\n    float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual\n    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);\n\n    for (int l = L-1; l >= 0; l--) {\n\n        residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;\n        dresidual = l == 0 ? 
grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C;\n\n        // get the pointers of the weights for this layer\n        float* l_ln1w = params.ln1w + l * C;\n        float* l_qkvw = params.qkvw + l * 3*C * C;\n        float* l_attprojw = params.attprojw + l * C * C;\n        float* l_ln2w = params.ln2w + l * C;\n        float* l_fcw = params.fcw + l * 4*C * C;\n        float* l_fcprojw = params.fcprojw + l * C * 4*C;\n        // get the pointers of the gradients of the weights for this layer\n        float* dl_ln1w = grads.ln1w + l * C;\n        float* dl_ln1b = grads.ln1b + l * C;\n        float* dl_qkvw = grads.qkvw + l * 3*C * C;\n        float* dl_qkvb = grads.qkvb + l * 3*C;\n        float* dl_attprojw = grads.attprojw + l * C * C;\n        float* dl_attprojb = grads.attprojb + l * C;\n        float* dl_ln2w = grads.ln2w + l * C;\n        float* dl_ln2b = grads.ln2b + l * C;\n        float* dl_fcw = grads.fcw + l * 4*C * C;\n        float* dl_fcb = grads.fcb + l * 4*C;\n        float* dl_fcprojw = grads.fcprojw + l * C * 4*C;\n        float* dl_fcprojb = grads.fcprojb + l * C;\n        // get the pointers of the activations for this layer\n        float* l_ln1 = acts.ln1 + l * B * T * C;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        float* l_qkv = acts.qkv + l * B * T * 3*C;\n        float* l_atty = acts.atty + l * B * T * C;\n        float* l_att = acts.att + l * B * NH * T * T;\n        float* l_residual2 = acts.residual2 + l * B * T * C;\n        float* l_ln2 = acts.ln2 + l * B * T * C;\n        float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        float* l_fch = acts.fch + l * B * T * 4*C;\n        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;\n        // get the pointers of the gradients of the activations for this layer\n        float* dl_ln1 = grads_acts.ln1 + l * B * T * C;\n        float* dl_qkv = 
grads_acts.qkv + l * B * T * 3*C;\n        float* dl_atty = grads_acts.atty + l * B * T * C;\n        float* dl_preatt = grads_acts.preatt + l * B * NH * T * T;\n        float* dl_att = grads_acts.att + l * B * NH * T * T;\n        float* dl_attproj = grads_acts.attproj + l * B * T * C;\n        float* dl_residual2 = grads_acts.residual2 + l * B * T * C;\n        float* dl_ln2 = grads_acts.ln2 + l * B * T * C;\n        float* dl_fch = grads_acts.fch + l * B * T * 4*C;\n        float* dl_fch_gelu = grads_acts.fch_gelu + l * B * T * 4*C;\n        float* dl_fcproj = grads_acts.fcproj + l * B * T * C;\n        float* dl_residual3 = grads_acts.residual3 + l * B * T * C;\n\n        // backprop this layer\n        residual_backward(dl_residual2, dl_fcproj, dl_residual3, B*T*C);\n        matmul_backward(dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C);\n        gelu_backward(dl_fch, l_fch, dl_fch_gelu, B*T*4*C);\n        matmul_backward(dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C);\n        layernorm_backward(dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C);\n        residual_backward(dresidual, dl_attproj, dl_residual2, B*T*C);\n        matmul_backward(dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C);\n        attention_backward(dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH);\n        matmul_backward(dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C);\n        layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);\n    }\n    encoder_backward(grads.wte, grads.wpe, grads_acts.encoded, model->inputs, B, T, C);\n}\n\nvoid gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {\n    // reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html\n\n    // lazily allocate the memory for m_memory and 
v_memory\n    if (model->m_memory == NULL) {\n        model->m_memory = (float*)calloc(model->num_parameters, sizeof(float));\n        model->v_memory = (float*)calloc(model->num_parameters, sizeof(float));\n    }\n\n    for (size_t i = 0; i < model->num_parameters; i++) {\n        float param = model->params_memory[i];\n        float grad = model->grads_memory[i];\n\n        // update the first moment (momentum)\n        float m = beta1 * model->m_memory[i] + (1.0f - beta1) * grad;\n        // update the second moment (RMSprop)\n        float v = beta2 * model->v_memory[i] + (1.0f - beta2) * grad * grad;\n        // bias-correct both moments\n        float m_hat = m / (1.0f - powf(beta1, t));\n        float v_hat = v / (1.0f - powf(beta2, t));\n\n        // update\n        model->m_memory[i] = m;\n        model->v_memory[i] = v;\n        model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);\n    }\n}\n\nvoid gpt2_free(GPT2 *model) {\n    free(model->params_memory);\n    free(model->grads_memory);\n    free(model->m_memory);\n    free(model->v_memory);\n    free(model->acts_memory);\n    free(model->grads_acts_memory);\n    free(model->inputs);\n    free(model->targets);\n}\n\n#ifndef TESTING\n// if we are TESTING (see test_gpt2.c), we'll skip the int main below\n// ----------------------------------------------------------------------------\n// sampler\n\nunsigned int random_u32(uint64_t *state) {\n    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A\n    *state ^= *state >> 12;\n    *state ^= *state << 25;\n    *state ^= *state >> 27;\n    return (*state * 0x2545F4914F6CDD1Dull) >> 32;\n}\nfloat random_f32(uint64_t *state) { // random float32 in [0,1)\n    return (random_u32(state) >> 8) / 16777216.0f;\n}\n\nint sample_mult(float* probabilities, int n, float coin) {\n    // sample index from probabilities (they must sum to 1!)\n    // coin is a random number in [0, 1), usually from random_f32()\n    
float cdf = 0.0f;\n    for (int i = 0; i < n; i++) {\n        cdf += probabilities[i];\n        if (coin < cdf) {\n            return i;\n        }\n    }\n    return n - 1; // in case of rounding errors\n}\n\n// ----------------------------------------------------------------------------\n// main training loop\nint main() {\n\n    // build the GPT-2 model from a checkpoint\n    GPT2 model;\n    gpt2_build_from_checkpoint(&model, \"gpt2_124M.bin\");\n\n    // build the DataLoaders from tokens files. for now use tiny_shakespeare if available, else tiny_stories\n    const char* tiny_stories_train = \"dev/data/tinystories/TinyStories_train.bin\";\n    const char* tiny_stories_val = \"dev/data/tinystories/TinyStories_val.bin\";\n    const char* tiny_shakespeare_train = \"dev/data/tinyshakespeare/tiny_shakespeare_train.bin\";\n    const char* tiny_shakespeare_val = \"dev/data/tinyshakespeare/tiny_shakespeare_val.bin\";\n    const char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? tiny_shakespeare_train : tiny_stories_train;\n    const char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val;\n    int B = 4; // batch size 4 (i.e. 4 independent token sequences will be trained on)\n    int T = 64; // sequence length 64 (i.e. each sequence is 64 tokens long). 
must be <= maxT, which is 1024 for GPT-2\n    DataLoader train_loader, val_loader;\n    dataloader_init(&train_loader, train_tokens, B, T, 0, 1, 1);\n    dataloader_init(&val_loader, val_tokens, B, T, 0, 1, 0);\n    printf(\"train dataset num_batches: %zu\\n\", train_loader.num_tokens / (B*T));\n    printf(\"val dataset num_batches: %zu\\n\", val_loader.num_tokens / (B*T));\n    int val_num_batches = 5;\n\n    // build the Tokenizer\n    Tokenizer tokenizer;\n    tokenizer_init(&tokenizer, \"gpt2_tokenizer.bin\");\n\n    // some memory for generating samples from the model\n    uint64_t rng_state = 1337;\n    int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));\n    const int genT = 64; // number of steps of inference we will do\n\n    // train\n    struct timespec start, end;\n    for (int step = 0; step <= 40; step++) {\n\n        // once in a while estimate the validation loss\n        if (step % 10 == 0) {\n            float val_loss = 0.0f;\n            dataloader_reset(&val_loader);\n            for (int i = 0; i < val_num_batches; i++) {\n                dataloader_next_batch(&val_loader);\n                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);\n                val_loss += model.mean_loss;\n            }\n            val_loss /= val_num_batches;\n            printf(\"val loss %f\\n\", val_loss);\n        }\n\n        // once in a while do model inference to print generated text\n        if (step > 0 && step % 20 == 0) {\n            // fill up gen_tokens with the GPT2_EOT, which kicks off the generation\n            for(int i = 0; i < B * T; ++i) {\n                gen_tokens[i] = tokenizer.eot_token;\n            }\n            // now sample from the model autoregressively\n            printf(\"generating:\\n---\\n\");\n            for (int t = 1; t < genT; t++) {\n                // note that inference is very wasteful here because for each token\n                // we re-calculate the forward pass for all of (B,T) 
positions from scratch\n                // but the inference here is just for sanity checking anyway\n                // and we can maybe optimize a bit more later, with careful tests\n                gpt2_forward(&model, gen_tokens, NULL, B, T);\n                // furthermore, below we're only using b=0 (i.e. the first row) of all B rows\n                // we're in principle running B \"inference streams\" in parallel here\n                // but only using position 0\n                // get the Vp-dimensional vector probs[0, t-1, :]\n                float* probs = model.acts.probs + (t-1) * model.config.padded_vocab_size;\n                float coin = random_f32(&rng_state);\n                // note we're only sampling from the first V elements, ignoring padding\n                // (the probabilities in the padded region should be zero anyway)\n                int next_token = sample_mult(probs, model.config.vocab_size, coin);\n                gen_tokens[t] = next_token;\n                // print the generated token, either using the Tokenizer or a fallback\n                if (tokenizer.init_ok) {\n                    const char* token_str = tokenizer_decode(&tokenizer, next_token);\n                    safe_printf(token_str);\n                } else {\n                    // fall back to printing the token id\n                    printf(\"%d \", next_token);\n                }\n                fflush(stdout);\n            }\n            printf(\"\\n---\\n\");\n        }\n\n        // do a training step\n        clock_gettime(CLOCK_MONOTONIC, &start);\n        dataloader_next_batch(&train_loader);\n        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);\n        gpt2_zero_grad(&model);\n        gpt2_backward(&model);\n        gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);\n        clock_gettime(CLOCK_MONOTONIC, &end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;\n     
   printf(\"step %d: train loss %f (took %f ms)\\n\", step, model.mean_loss, time_elapsed_s * 1000);\n    }\n\n    // free\n    dataloader_free(&train_loader);\n    dataloader_free(&val_loader);\n    tokenizer_free(&tokenizer);\n    gpt2_free(&model);\n    free(gen_tokens);\n    return 0;\n}\n#endif\n"
  },
  {
    "path": "train_gpt2.cu",
    "content": "/*\nGPT-2 Transformer Neural Net training loop. See README.md for usage.\n*/\n#include <unistd.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <stdarg.h>\n#include <string>\n#include <string_view>\n#include <sys/stat.h>\n#include <sys/types.h>\n// ----------- CPU utilities -----------\n// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck\n// defines: create_dir_if_not_exists, find_max_step, ends_with_bin\n#include \"llmc/utils.h\"\n// defines: tokenizer_init, tokenizer_decode, tokenizer_free\n#include \"llmc/tokenizer.h\"\n// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free\n// defines: evalloader_init, evalloader_reset, evalloader_next_batch, evalloader_free\n#include \"llmc/dataloader.h\"\n// defines: manual_seed, normal_ (same as torch.manual_seed and torch.normal)\n#include \"llmc/rand.h\"\n// defines: lr_scheduler_init, get_learning_rate\n#include \"llmc/schedulers.h\"\n// defines: sample_softmax, random_f32\n#include \"llmc/sampler.h\"\n// defines: logger_init, logger_log_eval, logger_log_val, logger_log_train\n#include \"llmc/logger.h\"\n// defines: get_flops_promised\n#include \"llmc/mfu.h\"\n// defines: OutlierDetector, init_detector, update_detector\n#include \"llmc/outlier_detector.h\"\n// ----------- GPU utilities -----------\n// defines:\n// WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE\n// NVTX_RANGE_FN\n#include \"llmc/cuda_common.h\"\n// defines:\n// Packed128, f128, x128\n// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged\n#include \"llmc/cuda_utils.cuh\"\n// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace\n// defines: cublas_compute, cublaslt_handle, cublas_handle\n#include \"llmc/cublas_common.h\"\n// ----------- Layer implementations in CUDA -----------\n// defines: encoder_forward, encoder_backward\n#include \"llmc/encoder.cuh\"\n// defines: layernorm_forward, 
residual_forward, fused_residual_forward5, layernorm_backward\n#include \"llmc/layernorm.cuh\"\n// defines: matmul_cublaslt, matmul_forward, matmul_backward, gelu_forward, gelu_backward_inplace\n#include \"llmc/matmul.cuh\"\n#ifdef ENABLE_CUDNN\n// defines: create_cudnn, destroy_cudnn, attention_forward_cudnn, attention_backward_cudnn\n#include \"llmc/cudnn_att.h\"\n#else\n// defines: attention_forward, attention_backward\n#include \"llmc/attention.cuh\"\n#endif\n// defines: fused_classifier\n#include \"llmc/fused_classifier.cuh\"\n// defines: adamw_kernel3\n#include \"llmc/adamw.cuh\"\n// defines: global_norm_squared\n#include \"llmc/global_norm.cuh\"\n// ----------- Multi-GPU support -----------\n// defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo\n// defines: printf0, multi_gpu_config\n// defines: multi_gpu_config_init, multi_gpu_config_free\n// defines: set_zero_configs, multi_gpu_cpu_float_sum, multi_gpu_barrier\n// defines: multi_gpu_get_shard_offset, multi_gpu_async_reduce_gradient\n#include \"llmc/zero.cuh\"\n\n// ----------------------------------------------------------------------------\n// global vars for I/O\nchar filename_buffer[512];\n\n// ----------------------------------------------------------------------------\n// global vars containing information about the GPU this process is running on\ncudaDeviceProp deviceProp; // fills in common_start()\ncudaStream_t main_stream;\n// buffer size to use for device <-> disk io\nconstexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024;\n\n// ----------------------------------------------------------------------------\n// GPT-2 model definition\n\ntypedef struct {\n    int max_seq_len; // max sequence length, e.g. 1024\n    int vocab_size; // vocab size, e.g. 50257\n    int padded_vocab_size; // padded to e.g. %128==0, 50304\n    int num_layers; // number of layers, e.g. 12\n    int num_heads; // number of heads in attention, e.g. 12\n    int channels; // number of channels, e.g. 
768\n} GPT2Config;\n\n// the parameters of the model\nconstexpr const int NUM_PARAMETER_TENSORS = 16;\ntypedef struct {\n    floatX* wte; // (V, C)\n    floatX* wpe; // (maxT, C)\n    floatX* ln1w; // (L, C)\n    floatX* ln1b; // (L, C)\n    floatX* qkvw; // (L, 3*C, C)\n    floatX* qkvb; // (L, 3*C)\n    floatX* attprojw; // (L, C, C)\n    floatX* attprojb; // (L, C)\n    floatX* ln2w; // (L, C)\n    floatX* ln2b; // (L, C)\n    floatX* fcw; // (L, 4*C, C)\n    floatX* fcb; // (L, 4*C)\n    floatX* fcprojw; // (L, C, 4*C)\n    floatX* fcprojb; // (L, C)\n    floatX* lnfw; // (C)\n    floatX* lnfb; // (C)\n} ParameterTensors;\nstatic_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), \"Inconsistent sizes!\");\n\nvoid fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) {\n    size_t Vp = config.padded_vocab_size;\n    size_t C = config.channels;\n    size_t maxT = config.max_seq_len;\n    size_t L = config.num_layers;\n    param_sizes[0] = Vp * C; // wte\n    param_sizes[1] = maxT * C; // wpe\n    param_sizes[2] = L * C; // ln1w\n    param_sizes[3] = L * C; // ln1b\n    param_sizes[4] = L * (3 * C) * C; // qkvw\n    param_sizes[5] = L * (3 * C); // qkvb\n    param_sizes[6] = L * C * C; // attprojw\n    param_sizes[7] = L * C; // attprojb\n    param_sizes[8] = L * C; // ln2w\n    param_sizes[9] = L * C; // ln2b\n    param_sizes[10] = L * (4 * C) * C; // fcw\n    param_sizes[11] = L * (4 * C); // fcb\n    param_sizes[12] = L * C * (4 * C); // fcprojw\n    param_sizes[13] = L * C; // fcprojb\n    param_sizes[14] = C; // lnfw\n    param_sizes[15] = C; // lnfb\n\n    // populate the parameter sizes in bytes (all the same for now, keeping for future use)\n    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        param_sizeof[i] = sizeof(floatX);\n    }\n}\n\n// allocate memory for the parameters and point the individual tensors to the right places\nvoid* malloc_and_point_parameters(ParameterTensors* params, 
size_t* param_elements, size_t *param_sizeof) {\n    // calculate the total number of parameters and bytes across all tensors\n    size_t num_parameters_bytes = 0;\n    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        num_parameters_bytes += param_elements[i] * param_sizeof[i];\n    }\n    // malloc all parameters all at once on the device\n    void* params_memory;\n    cudaCheck(cudaMalloc((void**)&params_memory, num_parameters_bytes));\n    // assign all the tensors their place in the array\n    floatX** ptrs[] = {\n        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,\n        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,\n        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb\n    };\n    char* params_memory_iterator = (char*)params_memory;\n    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        *(ptrs[i]) = (floatX*)params_memory_iterator;\n        params_memory_iterator += param_elements[i] * param_sizeof[i];\n    }\n    return params_memory;\n}\n\nconstexpr int NUM_ACTIVATION_TENSORS = 21;\ntypedef struct {\n    floatX* encoded; // (B, T, C)\n    floatX* ln1; // (L, B, T, C)\n    float* ln1_mean; // (L, B, T)\n    float* ln1_rstd; // (L, B, T)\n    floatX* atty; // (L, B, T, C)\n    // cuDNN saves only some statistics information\n#if ENABLE_CUDNN\n    float* att;  // (L, B, NH, T)\n#else\n    floatX* att; // (L, B, NH, T, T)\n#endif\n\n    floatX* residual2; // (L, B, T, C)\n    floatX* ln2; // (L, B, T, C)\n    float* ln2_mean; // (L, B, T)\n    float* ln2_rstd; // (L, B, T)\n    floatX* fch; // (L, B, T, 4*C)\n    floatX* fch_gelu; // (L, B, T, 4*C)\n    floatX* residual3; // (L, B, T, C)\n    floatX* lnf; // (B, T, C);   if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms\n    float* lnf_mean; // (B, T)\n    float* lnf_rstd; // (B, T)\n    float* losses; // (B, T), will be accumulated in 
micro-steps\n    // adding these two compared to the CPU .c code, needed for attention kernel as buffers\n    floatX* qkvr; // (L, B, T, 3*C)\n    // in inference mode, this buffer will store the logits\n    // in training mode, this buffer will contain the *gradients* of the logits.\n    // during the processing of transformer blocks, we will also use this as a\n    // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C),\n    // (B, NH, T, T), and (B, T, V) shaped tensors.\n    floatX* output;\n\n    // some additional scratch buffers\n    floatX* scratch_bt4c;   // (B, T, 4*C)\n    floatX* scratch_btc;    // (B, T, C)\n} ActivationTensors;\n\n\nstruct TensorSpec {\n    void** ptr;\n    size_t size;\n    DType type;\n};\n\n\n#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)};\n\nvoid fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) {\n    size_t Vp = config.padded_vocab_size;\n    size_t L = config.num_layers;\n    size_t NH = config.num_heads;\n    size_t C = config.channels;\n    tensors[0] = TENSOR_SPEC(data->encoded, B * T * C);\n    // if recompute >= 2 then we will recompute the layernorm forward activation during backward pass\n    tensors[1] = TENSOR_SPEC(data->ln1,  (recompute < 2) ? 
L * B * T * C : 0);\n    tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T);\n    tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T);\n    tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C);\n    #ifdef ENABLE_CUDNN\n    // FP32 stats tensor for cuDNN to be passed to backward pass\n    tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T);\n    #else\n    tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T);\n    #endif\n    tensors[6] = TENSOR_SPEC(data->residual2, L * B * T * C);\n    // if recompute >= 2 then we will recompute the layernorm forward activation during backward pass\n    tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0);\n    tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T);\n    tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T);\n    tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C);\n    // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer\n    tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? 
L * B * T * 4*C : B * T * 4*C);\n    tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C);\n    tensors[13] = TENSOR_SPEC(data->lnf, B * T * C);\n    tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T);\n    tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T);\n    tensors[16] = TENSOR_SPEC(data->losses, B * T);\n    tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C);\n    tensors[18] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp)));\n\n    tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C);\n    tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C);\n}\n\nvoid* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) {\n    size_t bytes = 0;\n    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n        bytes += tensors[i].size * sizeof_dtype(tensors[i].type);\n    }\n\n    printf0(\"allocating %d MiB for activations\\n\", (int)round(bytes / (1024 * 1024)));\n\n    void* acts_memory;\n    cudaCheck(cudaMalloc((void**)&acts_memory, bytes));\n\n    // cudaMalloc does not guarantee initial memory values so we memset the allocation here\n    // this matters because e.g. 
non-cuDNN attention assumes the attention buffer is zeroed\n    // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer\n    cudaCheck(cudaMemset(acts_memory, 0, bytes));\n\n    char* acts_memory_iterator = (char*)acts_memory;\n    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n        // extra protection so we don't accidentally use an empty buffer\n        if(tensors[i].size == 0) {\n            *(tensors[i].ptr) = NULL;\n        }else {\n            *(tensors[i].ptr) = acts_memory_iterator;\n            acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type);\n        }\n    }\n    return acts_memory;\n}\n\ntypedef struct {\n    GPT2Config config;\n    // the weights of the model, and their sizes\n    ParameterTensors params;\n    size_t param_elements[NUM_PARAMETER_TENSORS];\n    size_t param_sizeof[NUM_PARAMETER_TENSORS];\n    void* params_memory;\n    size_t num_parameters;\n    size_t num_parameters_bytes;\n    // gradients of the weights\n    ParameterTensors grads;\n    void* grads_memory;\n    // buffers for the AdamW optimizer\n    float* m_memory;\n    float* v_memory;\n    float* master_weights;     // is NULL unless fp32 weights is enabled.\n    // the activations of the model, and their sizes\n    ActivationTensors acts;\n    TensorSpec acts_specs[NUM_ACTIVATION_TENSORS];\n    void* acts_memory;\n    // other run state configuration\n    int batch_size; // the batch size (B) of current forward pass\n    int seq_len; // the sequence length (T) of current forward pass\n    int* inputs; // the input tokens for the current forward pass\n    int* targets; // the target tokens for the current forward pass\n    float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps\n    float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps\n    float* cpu_losses; // CPU buffer to copy the losses to, allocated with 
cudaMallocHost\n    unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.\n    unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights\n    int use_master_weights; // keep master weights copy in float for optim update? 0|1\n    bool init_state;   // set to true if master weights need to be initialized\n    int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward)\n    int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2\n    // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?\n    int* workload_indices; // encoder_backward, B*T*num_c_groups (int)\n    int4* bucket_info;     // encoder_backward, B*T*num_c_groups (int4) - size for worst case\n} GPT2;\n\nvoid gpt2_init_common(GPT2 *model) {\n    // common inits outside of the model weights\n    // memory lazily initialized in forward()\n    model->acts_memory = NULL;\n    model->inputs = NULL;\n    model->targets = NULL;\n    model->accumulated_mean_loss = NULL;\n    model->cpu_losses = NULL;\n    // the B,T params are determined and set, fixed on first batch in forward()\n    model->batch_size = 0;\n    model->seq_len = 0;\n    model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward()\n    model->params_memory = NULL;\n    // memory lazily initialized in backward()\n    model->grads_memory = NULL;\n    model->workload_indices = NULL; // on cpu, for encoder_backward\n    model->bucket_info = NULL; // on cpu, for encoder_backward\n    // memory lazily initialized in update()\n    model->m_memory = NULL;\n    model->v_memory = NULL;\n    model->master_weights = NULL;\n    // other default settings\n    model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding\n    model->use_master_weights = 1; // safe default: do keep master weights in fp32\n    model->init_state = true;\n    
model->recompute = 1; // good default: recompute gelu but not layernorm\n    model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())\n}\n\nvoid gpt2_allocate_weights(GPT2 *model) {\n    // fill in all the parameter tensor dimensions and types\n    fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config);\n    model->num_parameters = 0;\n    model->num_parameters_bytes = 0;\n    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        model->num_parameters += model->param_elements[i];\n        model->num_parameters_bytes += model->param_elements[i] * model->param_sizeof[i];\n    }\n    // create memory for model parameters on the device\n    assert(model->params_memory == nullptr);\n    model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);\n}\n\nvoid gpt2_allocate_state(GPT2 *model, int B, int T) {\n    printf0(\"allocating %d MiB for parameter gradients\\n\", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024)));\n    assert(model->grads_memory == nullptr);\n    model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof);\n\n    // record the current B,T as well\n    model->batch_size = B;\n    model->seq_len = T;\n\n    // allocate the space\n    fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute);\n    model->acts_memory = malloc_and_point_activations(model->acts_specs);\n    // also create memory for caching inputs and targets\n    cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));\n    cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));\n    cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));\n    cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));\n\n    // initialise cpu scratch buffers for encoder backward\n    size_t num_c_groups 
= CEIL_DIV(model->config.channels, (WARP_SIZE * x128::size));\n    assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?)\n    model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);\n    model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);\n\n    // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device\n    // and returns a status code of 1 if it had to fall back, in that case we want to print warning.\n    int memory_status = 0;\n\n    // we will now init the optimizer states and master weights\n    // this is usually a substantial amount of memory allocation right here.\n    size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for\n    printf0(\"allocating %zu MiB for AdamW optimizer state m\\n\", (shard_num_parameters * sizeof(float)) >> 20);\n    printf0(\"allocating %zu MiB for AdamW optimizer state v\\n\", (shard_num_parameters * sizeof(float)) >> 20);\n    assert(model->m_memory == nullptr);\n    assert(model->v_memory == nullptr);\n    memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float));\n    memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float));\n\n    if (model->use_master_weights == 1) {\n        assert(model->master_weights == nullptr);\n        printf0(\"allocating %zu MiB for master copy of params\\n\", (shard_num_parameters * sizeof(float)) >> 20);\n        memory_status |= cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float));\n    }\n\n    // report on mixed memory allocation status (re-using our float reduce function, bit awk ok)\n    int reduced_memory_status = (int) 
multi_gpu_cpu_float_sum((float)memory_status, &multi_gpu_config);\n    if (reduced_memory_status >= 1) {\n        printf0(\"WARNING: Fell back to cudaMallocManaged when initializing m,v,master_weights on %d GPUs\\n\", reduced_memory_status);\n        printf0(\"         Prevents an OOM, but code may run much slower due to device <-> host memory movement\\n\");\n    }\n    // report on device memory usage\n    size_t free, total;\n    cudaCheck(cudaMemGetInfo(&free, &total));\n    printf0(\"device memory usage: %zd MiB / %zd MiB\\n\", (total-free) / 1024 / 1024, total / 1024 / 1024);\n    // give an estimate of the maximum batch size\n    size_t bytes_per_sequence = 0;\n    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n        bytes_per_sequence += model->acts_specs[i].size * sizeof_dtype(model->acts_specs[i].type) / B;\n    }\n    printf0(\"memory per sequence: %zu MiB\\n\", bytes_per_sequence / 1024 / 1024);\n    printf0(\" -> estimated maximum batch size: %zu\\n\", B + free / bytes_per_sequence);\n}\n\nvoid gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {\n    // write the model to a checkpoint file\n    printf0(\"Writing model to %s\\n\", checkpoint_path);\n    FILE *model_file = fopenCheck(checkpoint_path, \"wb\");\n    // write the header first\n    int model_header[256];\n    memset(model_header, 0, sizeof(model_header));\n    model_header[0] = 20240326; // magic number\n    assert(PRECISION_MODE == PRECISION_FP32 || PRECISION_MODE == PRECISION_BF16);\n    model_header[1] = PRECISION_MODE == PRECISION_FP32 ? 
3 : 5; // version\n    model_header[2] = model->config.max_seq_len;\n    model_header[3] = model->config.vocab_size;\n    model_header[4] = model->config.num_layers;\n    model_header[5] = model->config.num_heads;\n    model_header[6] = model->config.channels;\n    model_header[7] = model->config.padded_vocab_size;\n    fwriteCheck(model_header, sizeof(int), 256, model_file);\n    // write the parameters\n    device_to_file(model_file, model->params_memory, model->num_parameters_bytes,\n                   IO_BUF_SIZE, main_stream);\n    // close file, we're done\n    fcloseCheck(model_file);\n}\n\nvoid gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool weight_init=true) {\n    // If weight_init is true, we will load the weights from this checkpoint .bin file\n    // We sometimes want this to be false, if we are going to initialize these weights from\n    // the master weights that are instead stored in the state .bin file.\n    // In that case, this function mostly loads the model hyperparameters from the header.\n\n    if (PRECISION_MODE == PRECISION_FP16) {\n        // TODO for later perhaps, would require us dynamically converting the\n        // model weights from fp32 to fp16 online, here in this function, or writing\n        // the fp16 weights directly from Python, which we only do for fp32/bf16 atm.\n        fprintf(stderr, \"build_from_checkpoint() does not support fp16 right now.\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // read in model from a checkpoint file\n    FILE *model_file = fopenCheck(checkpoint_path, \"rb\");\n    int model_header[256];\n    freadCheck(model_header, sizeof(int), 256, model_file);\n    if (model_header[0] != 20240326) { printf(\"Bad magic model file\\n\"); exit(EXIT_FAILURE); }\n    int version = model_header[1];\n    if (!(version == 3 || version == 5)) {\n        // 3 = fp32, padded vocab\n        // 5 = bf16, padded vocab, layernorms also in bf16\n        fprintf(stderr, \"Bad version in model 
file\\n\");\n        fprintf(stderr, \"---> HINT: try to re-run `python train_gpt2.py`\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // check if the precision mode of the checkpoint matches the model precision\n    if (weight_init) {\n        if (PRECISION_MODE == PRECISION_BF16 && version != 5) {\n            fprintf(stderr, \"Precision is configured as BF16 but model at %s is not.\\n\", checkpoint_path);\n            fprintf(stderr, \"---> HINT: are you sure you're loading a _bf16.bin file?\\n\");\n            exit(EXIT_FAILURE);\n        }\n        if (PRECISION_MODE == PRECISION_FP32 && version != 3) {\n            fprintf(stderr, \"Precision is configured as FP32 but model at %s is not.\\n\", checkpoint_path);\n            fprintf(stderr, \"---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\\n\");\n            fprintf(stderr, \"---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\\n\");\n            exit(EXIT_FAILURE);\n        }\n    }\n\n    // read in hyperparameters\n    model->config.max_seq_len = model_header[2];\n    model->config.vocab_size = model_header[3];\n    model->config.num_layers = model_header[4];\n    model->config.num_heads = model_header[5];\n    model->config.channels = model_header[6];\n    model->config.padded_vocab_size = model_header[7];\n\n    // allocate memory for the model parameters\n    gpt2_allocate_weights(model);\n\n    // read in the parameters if weight_init is true\n    if (weight_init) {\n        assert(model->params_memory != NULL);\n        file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream);\n    }\n    fcloseCheck(model_file);\n\n    // only return from this function once we are certain the params are ready on the GPU\n    cudaCheck(cudaDeviceSynchronize());\n}\n\nvoid gpt2_set_hyperparameters(GPT2Config* config, const char* depth_str) {\n    int depth = atoi(depth_str);\n    assert(depth > 0); // 
atoi returns 0 if not a number\n    int channels, num_heads;\n    if      (depth == 6)  { channels = 384; num_heads = 6; }   // (unofficial) gpt2-tiny (30M)\n    else if (depth == 12) { channels = 768; num_heads = 12; }  // gpt2 (124M)\n    else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M)\n    else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M)\n    else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M)\n    else if (depth == 60) { channels = 1920; num_heads = 30; } // (unofficial) 2.7B\n    else if (depth == 72) { channels = 2880; num_heads = 30; } // (unofficial) 7.3B\n    else if (depth == 84) { channels = 3456; num_heads = 36; } // (unofficial) 12.2B\n    else { fprintf(stderr, \"Unsupported GPT-2 depth: %d\\n\", depth); exit(EXIT_FAILURE); }\n    config->num_layers = depth;\n    config->channels = channels;\n    config->num_heads = num_heads;\n    config->max_seq_len = 1024;\n}\n\nvoid gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) {\n    // we use channels instead of depth for GPT-3 because GPT-3 model depths are not one-to-one\n    // note that our models are not necessarily identical to GPT-3 because\n    // we use dense attention, not the alternating dense/banded attention of GPT-3\n    int channels = atoi(channels_str);\n    assert(channels > 0); // atoi returns 0 if not a number\n    int depth, head_size;\n    if      (channels == 384)   { depth = 6;  head_size = 64; }  // (unofficial) gpt3-tiny (31M)\n    else if (channels == 768)   { depth = 12; head_size = 64; }  // gpt3-small (125M)\n    else if (channels == 1024)  { depth = 24; head_size = 64; }  // gpt3-medium (350M)\n    else if (channels == 1536)  { depth = 24; head_size = 96; }  // gpt3-large (760M)\n    else if (channels == 2048)  { depth = 24; head_size = 128; } // gpt3-xl (1.3B) [heads fixed]\n    else if (channels == 2560)  { depth = 32; head_size = 80; }  // gpt3-2.7B\n    else if 
(channels == 4096)  { depth = 32; head_size = 128; } // gpt3-6.7B\n    else if (channels == 5140)  { depth = 40; head_size = 128; } // gpt3-13B; NOTE(review): 5140 % 128 != 0, so the assert below fails for this entry - confirm whether 5120 (40 * 128) was intended\n    else if (channels == 12288) { depth = 96; head_size = 128; } // gpt3 (175B)\n    else { fprintf(stderr, \"Unsupported GPT-3 channels: %d\\n\", channels); exit(EXIT_FAILURE); }\n    assert(channels % head_size == 0);\n    config->num_layers = depth;\n    config->channels = channels;\n    config->num_heads = channels / head_size;\n    config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2\n}\n\nvoid gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {\n    // The model descriptor can be:\n    // - legacy format \"dX\", where X is a number, e.g. \"d12\". This creates GPT-2 model with 12 layers.\n    // - new explicit format \"gpt2:dX\", same as above, e.g. \"gpt2:d48\" for GPT-2 with 48 layers.\n    // - \"gpt3:cX\", where X is now the channel count, e.g. \"gpt3:c768\" is the smallest GPT-3 model.\n\n    // check the valid prefixes and dispatch to the right setup function\n    assert(descriptor != NULL);\n    size_t len = strlen(descriptor);\n    if (len > 1 && descriptor[0] == 'd') {\n        gpt2_set_hyperparameters(&model->config, descriptor + 1); // pass along the depth str without the 'd'\n    } else if (len > 6 && strncmp(descriptor, \"gpt2:d\", 6) == 0) {\n        gpt2_set_hyperparameters(&model->config, descriptor + 6); // pass along the depth str without the 'gpt2:d'\n    } else if (len > 6 && strncmp(descriptor, \"gpt3:c\", 6) == 0) {\n        gpt3_set_hyperparameters(&model->config, descriptor + 6); // pass along the channels str without the 'gpt3:c'\n    } else {\n        fprintf(stderr, \"Unsupported model descriptor: %s\\n\", descriptor); exit(EXIT_FAILURE);\n    }\n\n    // both GPT-2 and GPT-3 use the same tokenizer with 50257 tokens\n    model->config.vocab_size = 50257;\n    model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel 
efficiency\n\n    gpt2_allocate_weights(model);\n\n    // allocate and random init the memory for all the parameters with GPT-2 schema\n    // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5)\n    // NOTE: assuming all parameters are of the type floatX, could be relaxed later\n    mt19937_state init_rng;\n    manual_seed(&init_rng, 42);\n    floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes);\n    memset(params_memory_cpu, 0, model->num_parameters_bytes);\n    // fill in all the weights with random values\n    float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers);\n    // we have to init all these tensors exactly in the order that PyTorch initializes them\n    // so that we can match them up and get correctness and exactly the same initial conditions\n    size_t L = model->config.num_layers;\n    size_t offset = 0;\n    for (int l = 0; l < L; l++) {\n        offset = 0;\n        for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n            // the layernorm parameters are all initialized to 1\n            if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once\n                for (size_t j = 0; j < model->param_elements[i]; j++) {\n                    params_memory_cpu[offset + j] = 1.0f;\n                }\n            }\n            // weights tensors are handled here\n            if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors\n              || i == 4 || i == 6 || i == 10 || i == 12) {\n                size_t n = model->param_elements[i];\n                size_t layer_offset = 0;\n                if (i == 0) {\n                    // for wte tensor (padded vocab) override to init V instead of Vp rows\n                    n = model->config.vocab_size * model->config.channels;\n                }\n                if (i == 4 || i == 6 || i == 10 || i == 12) {\n                    // weight tensors, we are only initializing layer l\n        
            assert(n % L == 0);\n                    n = n / L;\n                    layer_offset = l * n;\n                }\n                // in GPT-2, the projections back into the residual stream are additionally\n                // scaled by 1/sqrt(2*L) for training stability\n                float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;\n                // okay let's draw the random numbers and write them\n                float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));\n                normal_(fp32_buffer, n, 0.0f, scale, &init_rng);\n                for (size_t j = 0; j < n; j++) {\n                    params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j];\n                }\n                free(fp32_buffer);\n            }\n            offset += model->param_elements[i];\n        }\n    }\n\n    // copy them to GPU\n    cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));\n    free(params_memory_cpu);\n}\n\n// propagate inputs through the network to produce logits.\n// right now, this function is fully synchronous with the host\nvoid gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {\n    NVTX_RANGE_FN();\n    // we must be careful and use size_t instead of int, otherwise\n    // we could overflow int. E.g. 
l * B * NH * T * T overflows int at B 16.\n\n    // ensure the model was initialized or error out\n    if (model->params_memory == NULL) {\n        printf(\"Error: model was not initialized properly.\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // convenience parameters\n    const size_t V = model->config.vocab_size;\n    const size_t Vp = model->config.padded_vocab_size;\n    const size_t L = model->config.num_layers;\n    const size_t NH = model->config.num_heads;\n    const size_t C = model->config.channels;\n\n    // validate B,T are not larger than the values used at initialisation\n    // (smaller B,T are okay for inference only)\n    if (B > model->batch_size || T > model->seq_len) {\n        printf(\"Model: B=%d T=%d, Desired: B=%d T=%d\\n\", model->batch_size, model->seq_len, (int)B, (int)T);\n        exit(EXIT_FAILURE);\n    }\n\n    // copy inputs/targets to the model\n    cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice));\n    // validate inputs, all indices must be in the range [0, V)\n    // we can do this while the copies are already underway\n    tokenCheck(inputs, B*T, V);\n\n    // forward pass\n    ParameterTensors params = model->params; // for brevity\n    ActivationTensors acts = model->acts;\n    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]\n\n    // first layernorm isn't fused\n    layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);\n\n    for (int l = 0; l < L; l++) {\n        NvtxRange layer_range(\"Layer\", l);\n\n        floatX* residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C;\n\n        // get the pointers of the weights for this layer\n        floatX* l_qkvw = params.qkvw + l * 3*C * C;\n        floatX* l_qkvb = params.qkvb + l * 3*C;\n        floatX* l_attprojw = params.attprojw + l * C * C;\n        floatX* l_attprojb = params.attprojb + l * C;\n        floatX* l_ln2w = params.ln2w + l * C;\n        floatX* l_ln2b = params.ln2b + l * C;\n        floatX* l_fcw = params.fcw + l * 4*C * C;\n        floatX* l_fcb = params.fcb + l * 4*C;\n        floatX* l_fcprojw = params.fcprojw + l * C * 4*C;\n        floatX* l_fcprojb = params.fcprojb + l * C;\n\n        // get the pointers of the activations for this layer\n        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;\n        floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;\n        floatX* l_atty = acts.atty + l * B * T * C;\n        floatX* l_residual2 = acts.residual2 + l * B * T * C;\n        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;\n        float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        floatX* l_fch = acts.fch + l * B * T * 4*C;\n        // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward\n        // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size\n        floatX* l_fch_gelu = (model->recompute < 1) ? 
acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;\n        floatX* l_residual3 = acts.residual3 + l * B * T * C;\n        floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc.\n\n        // now do the forward pass\n        #ifdef ENABLE_CUDNN\n        float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor\n        matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);\n        attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream);\n        #else\n        floatX* l_att = acts.att + l * B * NH * T * T;\n        if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent)\n            cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX)));\n        }\n        // these are only needed as scratchpads for the forward pass, but\n        // need not be stored for backward\n        matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);\n        attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream);\n        #endif\n\n        matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);\n        fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream);\n        matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion);\n        matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream);\n        // OK, fusion across blocks.\n        if(l+1 != L) {\n            floatX* l_ln1 = (model->recompute < 2) ? 
acts.ln1 + (l + 1) * B * T * C : acts.lnf;\n            float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T;\n            float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T;\n            const floatX* l_ln1w = params.ln1w + (l + 1) * C;\n            const floatX* l_ln1b = params.ln1b + (l + 1) * C;\n            fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b,\n                                    B * T, C, main_stream);\n        } else {\n            fused_residual_forward5(l_residual3, acts.lnf, acts.lnf_mean, acts.lnf_rstd, l_residual2, scratch,\n                                    params.lnfw, params.lnfb,\n                                    B * T, C, main_stream);\n        }\n    }\n\n    matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);\n    cudaCheck(cudaDeviceSynchronize());\n}\n\n\n// Forwards both the model and the loss and is used for validation splits and evals.\n// In particular it populates cpu_losses with loss at each token.\n// Some of the evals (e.g. 
HellaSwag) require the per-token losses, which are produced here.\nfloat gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) {\n    assert(targets != NULL);\n    // forward the model itself\n    gpt2_forward(model, inputs, B, T);\n    // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow\n    const size_t V = model->config.vocab_size;\n    const size_t Vp = model->config.padded_vocab_size;\n\n    NvtxRange classifier_and_loss_range(\"classifier_and_loss\");\n    ActivationTensors acts = model->acts;\n    float mean_loss = 0.0f;\n    // fused classifier: does the forward pass and first part of the backward pass\n    const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements\n    // note: we don't need to generate dlogits here\n    cudaCheck(cudaMemset(acts.losses, 0, B*T*sizeof(float)));\n    cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));\n    tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets\n    fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream);\n    cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost));\n    for (int i = 0; i < B*T; i++) {\n        mean_loss += model->cpu_losses[i];\n    }\n    mean_loss /= B*T;\n    cudaCheck(cudaDeviceSynchronize());\n    return mean_loss;\n}\n\nvoid gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) {\n    if(model->grads_memory == nullptr) {\n        fprintf(stderr, \"Need to allocate gradients before backward\");\n        exit(EXIT_FAILURE);\n    }\n    NVTX_RANGE_FN();\n    bool last_step = micro_step == grad_accum_steps - 1;\n    // on the first micro-step zero the gradients, as we're about to += accumulate into them\n    if (micro_step == 0) {\n        // there are currently two state vars 
during the gradient accumulation inner loop:\n        // 1) the losses accumulate += into acts.losses, reset here\n        // 2) the gradients accumulate += into grads_memory, reset here\n        cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream));\n        cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream));\n    }\n\n    // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow\n    const size_t B = model->batch_size;\n    const size_t T = model->seq_len;\n    const size_t V = model->config.vocab_size;\n    const size_t Vp = model->config.padded_vocab_size;\n    const size_t L = model->config.num_layers;\n    const size_t NH = model->config.num_heads;\n    const size_t C = model->config.channels;\n\n    ParameterTensors params = model->params; // for brevity\n    ParameterTensors grads = model->grads;\n    ActivationTensors acts = model->acts;\n\n    // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier\n    NvtxRange classifier_and_loss_range(\"classifier_and_loss\");\n    const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements\n    cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));\n    tokenCheck(targets, B*T, V);\n    fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream);\n\n    // backward pass: go in the reverse order of the forward pass, and call backward() functions\n\n    // reset residual stream gradients (put here to work with gradient accumulation)\n    floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass\n    cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX)));\n\n    // re-use the output buffer of the forward pass as a scratchpad during 
backward pass\n    float*  scratchF = (float*)acts.output;\n    floatX* scratchX = (floatX*)acts.output;\n\n    // we kick off the chain rule by filling in dlosses with 1.0f/(B*T)\n    // this was done in the fused classifier kernel as last step of forward pass\n    // technically that is a small, inline backward() pass of calculating\n    // total, final loss as the mean over all losses over all (B,T) positions in the batch\n    // next: backward the classifier matmul\n    matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);\n    // backward the final layernorm\n    floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3\n    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);\n\n    // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic\n    // scratch for backward computations\n    floatX* dl_btc = residual;\n\n    // now backward all the layers\n    for (int l = L-1; l >= 0; l--) {\n        NvtxRange layer_range(\"Layer\", l);\n\n        residual = l == 0 ? 
acts.encoded : acts.residual3 + (l-1) * B * T * C;\n\n        // get the pointers of the weights for this layer\n        floatX* l_ln1w = params.ln1w + l * C;\n        floatX* l_ln1b = params.ln1b + l * C;\n        floatX* l_qkvw = params.qkvw + l * 3*C * C;\n        floatX* l_attprojw = params.attprojw + l * C * C;\n        floatX* l_ln2w = params.ln2w + l * C;\n        floatX* l_ln2b = params.ln2b + l * C;\n        floatX* l_fcw = params.fcw + l * 4*C * C;\n        floatX* l_fcprojw = params.fcprojw + l * C * 4*C;\n        // get the pointers of the gradients of the weights for this layer\n        floatX* dl_ln1w = grads.ln1w + l * C;\n        floatX* dl_ln1b = grads.ln1b + l * C;\n        floatX* dl_qkvw = grads.qkvw + l * 3*C * C;\n        floatX* dl_qkvb = grads.qkvb + l * 3*C;\n        floatX* dl_attprojw = grads.attprojw + l * C * C;\n        floatX* dl_attprojb = grads.attprojb + l * C;\n        floatX* dl_ln2w = grads.ln2w + l * C;\n        floatX* dl_ln2b = grads.ln2b + l * C;\n        floatX* dl_fcw = grads.fcw + l * 4*C * C;\n        floatX* dl_fcb = grads.fcb + l * 4*C;\n        floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;\n        floatX* dl_fcprojb = grads.fcprojb + l * C;\n        // get the pointers of the activations for this layer\n        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;\n        floatX* l_atty = acts.atty + l * B * T * C;\n        floatX* l_residual2 = acts.residual2 + l * B * T * C;\n        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;\n        float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C;\n        floatX* l_fch_gelu = (model->recompute < 1) ? 
acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;\n        // get the pointers of the gradients of the activations for this layer\n        // notice that there is no l *, because we just have a single copy, and keep\n        // re-using this memory in every Transformer block as we calculate backward pass\n\n        floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c;\n\n        // start the backward pass for this layer\n        if(model->recompute >= 1) {\n            // recompute >= 1 means we recompute gelu. in this case,\n            // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here\n            gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream);\n        }\n        matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion);\n        if(model->recompute >= 2) {\n            // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand\n            layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream);\n        }\n        matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream);\n        // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above\n        layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream);\n        matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream);\n\n        #ifdef ENABLE_CUDNN\n        float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor\n        attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream);\n        #else\n        floatX* l_att = acts.att + l * B * NH * T * T;\n        // we need B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory\n        floatX* buffer_a = l_atty;\n        floatX* buffer_b = l_fch_pre_gelu;        // this is B x T x 4C, so even larger than what we need\n        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);\n        #endif\n        if(model->recompute >= 2) {\n            layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);\n        }\n        // QKV parameter gradients\n        matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream);\n        // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above\n        layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream);\n\n        // Accumulate gradients from this layer in a background stream.\n        if(last_step) {\n            floatX* const pointers[] = {\n                dl_ln1w, dl_ln1b,\n                dl_qkvw, dl_qkvb,\n                dl_attprojw, dl_attprojb,\n                dl_ln2w, dl_ln2b,\n                dl_fcw, dl_fcb,\n                dl_fcprojw, dl_fcprojb\n            };\n            const size_t nelem[] = {\n                C, C,\n                3 * C * C, 3 * C,\n                C * C, C,\n                C, C,\n                4 * C * C, 4 * C,\n                C * 4 * C, C\n            };\n            multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);\n        }\n    }\n    encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,\n                     dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);\n\n    // Aggregate all gradients that are not part of the transformer blocks\n    if(last_step) {\n        // reduce all the losses within the 
current GPU (across all microsteps)\n        global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream);\n        // reduce loss across GPUs to a single, final float across all microsteps and GPUs\n        #if MULTI_GPU\n        ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream));\n        #endif\n        cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));\n        // reduce the gradients for non-transformer block parameters\n        floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};\n        const size_t nelem[] = {Vp * C, T * C, C, C};\n        multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);\n    }\n\n    cudaCheck(cudaDeviceSynchronize());\n    if(last_step) {\n        model->mean_loss /= B*T*grad_accum_steps;\n    } else {\n        model->mean_loss = -1.f; // no loss available yet\n    }\n}\n\n// Gets the offset of a specific tensor for a specific layer in the GPT2 model\n// layer_id is ignored for weights that are not part of a transformer block\nShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) {\n    // first offset our way to the parameter tensor start\n    ptrdiff_t offset = 0;\n    for (int i = 0; i < param_tensor_id; i++) {\n        offset += (ptrdiff_t)model->param_elements[i];\n    }\n    size_t size = model->param_elements[param_tensor_id] ;\n    // if we are in the transformer block, we need to additionally offset by the layer id\n    if(2 <= param_tensor_id && param_tensor_id <= 13) {\n        size /= model->config.num_layers;\n        offset += (ptrdiff_t)(layer_id * size);\n    }\n    return {offset, size};\n}\n\nfloat gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {\n    NVTX_RANGE_FN();\n    floatX* grads_memory = 
(floatX*)model->grads_memory;\n\n    // repurposing this buffer (which isn't needed now) to write grad norm into it\n    float* grad_norm_squared = (float*)model->acts.output;\n    float grad_norm_squared_cpu = 0.0f;\n\n    int num_slices[2] = {1, model->config.num_layers};\n    int max_num_block_sums = get_max_num_block_sums(num_slices, 2);\n    if (multi_gpu_config->zero_stage == 1) {\n        // because of the ncclReduceScatter() in backward,\n        // grads_memory only contains the averaged gradients at the local shards,\n        // so we only calculate the grad norm at the grads_memory belonging to the local shards\n        for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n            ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i);\n            ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);\n            ptrdiff_t offset = tensor.offset + shard.offset;\n            bool is_first_pass = (i == 0);\n            if((i < 2 || i > 13)) {\n                global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, 0, 1,\n                                    max_num_block_sums, is_first_pass, main_stream);\n            } else {\n                global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, tensor.size, model->config.num_layers,\n                                    max_num_block_sums, is_first_pass, main_stream);\n            }\n        }\n        global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream);\n#if MULTI_GPU\n        // further sum the (partial) squared norm across all GPUs\n        ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream));\n#endif\n    } else {\n        // in regular DDP, backward has averaged the gradients across all GPUs\n        // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed\n        
global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream);\n        global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream);\n    }\n    cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost));\n    float grad_norm_cpu = sqrtf(grad_norm_squared_cpu);\n    return grad_norm_cpu;\n}\n\nvoid gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t,\n                 MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) {\n    // update the model parameters using the AdamW optimizer\n    // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs\n    // so we may not be responsible for the entire parameter tensor\n    // also, this function was very simple a while back but become very complex, only because we want to\n    // selectively weight decay some, but not all tensors :(\n    // TODO: revisit and probably refactor this entire function\n    NVTX_RANGE_FN();\n    if(model->grads_memory == nullptr || model->m_memory == nullptr || model->v_memory == nullptr) {\n        fprintf(stderr, \"Need to allocate optimizer state before update\");\n        exit(EXIT_FAILURE);\n    }\n\n    bool init_state = model->init_state;\n    if(init_state) {\n        model->init_state = false;\n        NvtxRange rng(\"InitOpt\");\n        cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));\n        cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float)));\n    }\n\n    // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint\n    model->rng_state_last_update = model->rng_state;\n\n    // AdamW update\n    // handle adamw for all the transformer blocks\n    for (int i = 0; i < 
NUM_PARAMETER_TENSORS; i++) {\n        // generate a unique seed for each tensor\n        unsigned int seed = random_u32(&model->rng_state);\n\n        int num_layers = model->config.num_layers;\n        if((i < 2 || i > 13)) {\n            num_layers = 1;\n        }\n\n        ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i);\n        ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);\n        ptrdiff_t local_offset_full = tensor.offset + shard.offset;\n        ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes;\n\n        // we only want to weight decay the 2D tensors and leave all 1D tensors alone\n        // in particular this also decays the embedding weights, but this is ok:\n        // - the token embeddings are weight shared and participate in the final projection to logits\n        // - the position embeddings actively participate at every forward/backward pass\n        float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f;\n        floatX* param_ptr = (floatX*)model->params_memory + local_offset_full;\n        floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full;\n\n        ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ?  
local_offset_full : local_offset_partial;\n        float* m_ptr = model->m_memory + opt_state_offset;\n        float* v_ptr = model->v_memory + opt_state_offset;\n        float* master_ptr = nullptr;\n        if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; }\n        if(init_state && model->master_weights != nullptr ) {\n            size_t grid_size = CEIL_DIV(shard.size, 512);\n            copy_and_cast_kernel<<<dim3(grid_size, num_layers), 512, 0, main_stream>>>(master_ptr, param_ptr, shard.size,\n                                                                     shard.size, tensor.size);\n            cudaCheck(cudaGetLastError());\n        }\n\n        if (init_from_master_only) {\n            // when resuming training from a checkpoint with master weights (allows changing precision)\n            init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream);\n        } else {\n            // ok finally call the kernel to update the weights with AdamW\n            adamw_update(param_ptr, master_ptr, grad_ptr,\n                        m_ptr, v_ptr,\n                        shard.size, tensor.size, tensor.size, shard.size, num_layers,\n                        learning_rate,\n                        beta1, beta2, t, eps, wd, grad_scale, seed, main_stream);\n        }\n\n        if (multi_gpu_config->zero_stage == 1) {\n#if MULTI_GPU\n            ncclCheck(ncclGroupStart());\n            for(int l = 0; l < num_layers; ++l) {\n                // gather updated shards of model->params_memory from each process\n                ncclCheck(ncclAllGather(param_ptr + l * tensor.size,\n                                        (floatX*) model->params_memory + tensor.offset + l * tensor.size,\n                                        shard.size, ncclFloatX,\n                                        multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream));\n            }\n            
ncclCheck(ncclGroupEnd());\n#endif\n        }\n    }\n\n    cudaCheck(cudaDeviceSynchronize());\n}\n\nfloat gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) {\n    /*\n    Estimate model flops utilization (MFU)\n    ref: Section 2.1 of https://arxiv.org/pdf/2001.08361\n    Note: Ideally, the N here would be only the parameters that actually\n    participate in matrix multiplications. In this N, we are over-estimating by\n    including LayerNorm params, biases, and the position embedding weights,\n    but these are very small terms. Also keep in mind that we would want to exclude\n    the token embedding weights, but in GPT-2 these are weight shared, so they\n    participate in the classifier matmul, so they are correct to be included in N.\n    Note 2: The first term (6 * N) in flops_per_token is all weight matmuls, the\n    second is the attention matmul, which is also usually a small contribution.\n    */\n    size_t N = model->num_parameters;\n    int L = model->config.num_layers;\n    int C = model->config.channels;\n    int T = model->seq_len;\n    size_t flops_per_token = 6 * N + (size_t)6 * L * C * T;\n    size_t flops_per_step = flops_per_token * num_tokens;\n    // express our flops throughput as ratio of A100 bfloat16 peak flops\n    float flops_achieved = (float)flops_per_step * (1.0f / dt); // per second\n    float flops_promised = get_flops_promised(deviceProp.name, PRECISION_MODE) * 1e12f;\n    if(flops_promised < 0) {\n        return -1.f;   // don't know\n    }\n    float mfu = flops_achieved / flops_promised;\n    return mfu;\n}\n\nvoid gpt2_free(GPT2 *model) {\n    cudaFreeCheck(&model->params_memory);\n    cudaFreeCheck(&model->grads_memory);\n    cudaFreeCheck(&model->m_memory);\n    cudaFreeCheck(&model->v_memory);\n    cudaFreeCheck(&model->master_weights);\n    cudaFreeCheck(&model->acts_memory);\n    cudaFreeCheck(&model->inputs);\n    cudaFreeCheck(&model->targets);\n    cudaFreeCheck(&model->accumulated_mean_loss);\n    
cudaCheck(cudaFreeHost(model->cpu_losses));\n    free(model->workload_indices);\n    free(model->bucket_info);\n}\n\n// ----------------------------------------------------------------------------\n// common init & free code for all of train/test/profile\n\nvoid common_start(bool override_enable_tf32 = true, bool print_device_info = true) {\n\n    // get CUDA device infos\n    cudaCheck(cudaGetDeviceProperties(&deviceProp, multi_gpu_config.local_device_idx));\n    if (print_device_info) {\n        printf(\"[System]\\n\");\n        printf(\"Device %d: %s\\n\", multi_gpu_config.local_device_idx, deviceProp.name);\n    }\n\n    // set up the cuda streams. atm everything is on the single main stream\n    cudaCheck(cudaStreamCreate(&main_stream));\n    nvtxNameCudaStreamA(main_stream, \"main stream\");\n\n    // set up cuBLAS and cuBLASLt\n    cublasCheck(cublasLtCreate(&cublaslt_handle));\n    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));\n\n    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')\n    bool enable_tf32 = PRECISION_MODE == PRECISION_FP32 && deviceProp.major >= 8 && override_enable_tf32;\n    cublas_compute = enable_tf32 ? 
CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;\n\n    #ifdef ENABLE_CUDNN\n    create_cudnn();\n    #endif\n}\n\nvoid common_free(GPT2 &model) {\n    cudaCheck(cudaStreamDestroy(main_stream));\n    cudaCheck(cudaFree(cublaslt_workspace));\n    cublasCheck(cublasLtDestroy(cublaslt_handle));\n    #ifdef ENABLE_CUDNN\n    destroy_cudnn();\n    #endif\n}\n\n\nvoid save_state(const char* filename, int step, GPT2* model, DataLoader* loader) {\n    printf(\"Writing state to %s\\n\", filename);\n    FILE *state_file = fopenCheck(filename, \"wb\");\n    int state_header[256];\n    memset(state_header, 0, sizeof(state_header));\n    // basic identifying information\n    state_header[0] = 20240527; // magic number\n    state_header[1] = 1; // version number\n    state_header[2] = multi_gpu_config.num_processes; // number of processes\n    state_header[3] = multi_gpu_config.process_rank; // rank of this process\n    state_header[4] = model->use_master_weights;  // whether we're using fp32 master weights\n    state_header[5] = loader->should_shuffle; // shuffle state of the dataloader\n    // int main state, start at 10 to leave some padding\n    state_header[10] = step; // step of the optimization\n    // model rng state, start at 20 to leave some padding\n    *((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state\n    *((unsigned long long*)&state_header[22]) = model->rng_state_last_update; // last gpt2_update\n    // dataloader state, start at 30 to leave some padding\n    *((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset\n    *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard\n    fwriteCheck(state_header, sizeof(int), 256, state_file);\n\n    // write AdamW m, v, and master_weights here (they are all float)\n    size_t shard_num_parameters = multi_gpu_config.shard_num_parameters;\n    device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), 
IO_BUF_SIZE, main_stream);\n    device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);\n    if(model->use_master_weights) {\n        device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);\n    }\n\n    // write dataloader state if we are using the Permuted version of it\n    if (loader->should_shuffle) {\n        fwriteCheck(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file);  // number of shards\n        fwriteCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file);\n        fwriteCheck(&loader->shard_num_samples, sizeof(size_t), 1, state_file);\n        fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file);\n        fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file);\n    }\n    fcloseCheck(state_file);\n}\n\nvoid load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) {\n    FILE *state_file = fopenCheck(filename, \"rb\");\n    int state_header[256];\n    freadCheck(state_header, sizeof(int), 256, state_file);\n    assert(state_header[0] == 20240527); // magic number\n    assert(state_header[1] == 1); // version number\n    assert(state_header[2] == multi_gpu_config.num_processes); // number of processes\n    assert(state_header[3] == multi_gpu_config.process_rank); // rank of this process\n    int use_master_weights = state_header[4];  // whether we're using fp32 master weights\n    int should_shuffle = state_header[5]; // shuffle state of the dataloader\n    *step = state_header[10]; // step of the optimization\n    model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state\n    model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update\n    size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index\n    size_t current_sample_idx = 
*((size_t*)&state_header[32]); // position in shard\n\n    // read AdamW m, v, master_weights (they are all float)\n    // allocate all the needed memory as necessary\n    size_t shard_num_parameters = multi_gpu_config.shard_num_parameters;\n    if(use_master_weights == 1 && !model->use_master_weights) {\n        printf0(\"Warning: Master weights are present in state, but not enabled for current run.\");\n    } else if (use_master_weights == 0 && model->use_master_weights) {\n        printf0(\"Error: Master weights requested, but not present in state file.\");\n        exit(EXIT_FAILURE);\n    }\n\n    model->init_state = false;      // we just got the state from file, no need to do first-touch init\n    assert(model->m_memory != nullptr);\n    assert(model->v_memory != nullptr);\n    file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);\n    file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);\n    if(model->use_master_weights) {\n        assert(model->master_weights != nullptr);\n        file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);\n        // restore weights from the master weights using the RNG state before last weight update\n        model->rng_state = model->rng_state_last_update;\n        gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true);\n        model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this\n    }\n\n    // revive the DataLoader object and its state\n    loader->should_shuffle = should_shuffle;\n    if (should_shuffle == 1) {\n        // ensure the number of shards matches\n        size_t glob_result_gl_pathc;\n        freadCheck(&glob_result_gl_pathc, sizeof(size_t), 1, state_file);\n        assert(glob_result_gl_pathc == loader->glob_result.gl_pathc);\n 
       // read the shard indices\n        loader->shard_indices = (int*)mallocCheck(loader->glob_result.gl_pathc * sizeof(int));\n        freadCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file);\n        // ensure the number of samples matches\n        size_t shard_num_samples;\n        freadCheck(&shard_num_samples, sizeof(size_t), 1, state_file);\n        assert(shard_num_samples == loader->shard_num_samples);\n        // read the intra-shard indices\n        loader->intra_shard_indices = (int*)mallocCheck(loader->shard_num_samples * sizeof(int));\n        freadCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file);\n        // read the shuffle rng state\n        freadCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file);\n    }\n    dataloader_resume(loader, current_shard_idx, current_sample_idx);\n\n    // all done, close state file\n    fcloseCheck(state_file);\n}\n\nvoid write_checkpoint(const char* output_log_dir, int step, GPT2* model, DataLoader* train_loader, MultiGpuConfig* multi_gpu_config) {\n    // a checkpoint contains: model weights, optimizer/dataloader state, and a DONE file\n    printf0(\"Writing checkpoint at step %d\\n\", step);\n    int rank = multi_gpu_config->process_rank;\n    // only rank 0 writes the model file because it is the same across all ranks\n    if (rank == 0) {\n        snprintf(filename_buffer, sizeof(filename_buffer), \"%s/model_%08d.bin\", output_log_dir, step);\n        gpt2_write_to_checkpoint(model, filename_buffer);\n    }\n    // all ranks write their state file\n    snprintf(filename_buffer, sizeof(filename_buffer), \"%s/state_%08d_%05d.bin\", output_log_dir, step, rank);\n    save_state(filename_buffer, step, model, train_loader);\n    // DONE file is a signal that this checkpoint as a whole is complete\n    multi_gpu_barrier(multi_gpu_config);\n    if (rank == 0) {\n        snprintf(filename_buffer, sizeof(filename_buffer), \"%s/DONE_%08d\", 
output_log_dir, step);\n        FILE* done_file = fopenCheck(filename_buffer, \"w\");\n        fcloseCheck(done_file);\n    }\n}\n\nvoid delete_checkpoint(const char* output_log_dir, int step, MultiGpuConfig* multi_gpu_config) {\n    // mirrors write_checkpoint function, cleans up checkpoint from disk\n    printf0(\"Deleting checkpoint at step %d\\n\", step);\n    int rank = multi_gpu_config->process_rank;\n    if (rank == 0) {\n        snprintf(filename_buffer, sizeof(filename_buffer), \"%s/model_%08d.bin\", output_log_dir, step);\n        remove(filename_buffer);\n    }\n    snprintf(filename_buffer, sizeof(filename_buffer), \"%s/state_%08d_%05d.bin\", output_log_dir, step, rank);\n    remove(filename_buffer);\n    if (rank == 0) {\n        snprintf(filename_buffer, sizeof(filename_buffer), \"%s/DONE_%08d\", output_log_dir, step);\n        remove(filename_buffer);\n    }\n}\n\n#ifndef TESTING\n// if we are TESTING (see test_gpt2.cu), we'll skip everything below this point\n\n// ----------------------------------------------------------------------------\n// training resumption logic, very useful when jobs crash once in a while\n// the goal is that we can resume optimization from any checkpoint, bit-perfect\n// note that \"state\" refers to things not already saved in the model checkpoint file\n\n// ----------------------------------------------------------------------------\n// CLI, poor man's argparse\n// (all single letters have been claimed now)\n\nvoid error_usage() {\n    fprintf(stderr, \"Usage:   ./train_gpt2cu [options]\\n\");\n    fprintf(stderr, \"Options:\\n\");\n    // file system input / output\n    fprintf(stderr, \"  -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\\n\");\n    fprintf(stderr, \"  -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\\n\");\n    fprintf(stderr, \"  -e <string> input .bin filename or descriptor, see code comments as 
docs. (default = gpt2_124M_bf16.bin)\\n\");\n    fprintf(stderr, \"  -o <string> output log dir (default = NULL, no logging)\\n\");\n    fprintf(stderr, \"  -lg <int>   log gpu info every x steps (default = -1; disabled)\\n\");\n    fprintf(stderr, \"  -n <int>    write optimization checkpoints every how many steps? (default 0, don't)\\n\");\n    fprintf(stderr, \"  -nk <int>   max number of checkpoints to keep in the directory, removing old ones (0 = disable, default)\\n\");\n    fprintf(stderr, \"  -nm <int>   every how many step checkpoints are considered major? major checkpoints never get deleted.\\n\");\n    fprintf(stderr, \"  -y <int>    resume optimization found inside output log dir? (0=restart/overwrite, 1=resume/append)\\n\");\n    // token layout for each step of the optimization\n    fprintf(stderr, \"  -b <int>    (per-GPU, micro) batch size B (default = 4)\\n\");\n    fprintf(stderr, \"  -t <int>    sequence length T (default = 1024)\\n\");\n    fprintf(stderr, \"  -d <int>    total desired batch size (default = B * T * num_processes, i.e. 
no grad accumulation)\\n\");\n    // workload (number of steps)\n    fprintf(stderr, \"  -x <int>    max_steps of optimization to run (-1 (default) = disable, run 1 epoch)\\n\");\n    // optimization\n    fprintf(stderr, \"  -k <string> learning rate scheduler (default = cosine)\\n\");\n    fprintf(stderr, \"  -l <float>  learning rate (default = 3e-4f)\\n\");\n    fprintf(stderr, \"  -u <int>    learning rate warmup iterations (default = 0, no warmup)\\n\");\n    fprintf(stderr, \"  -q <float>  learning rate decay: final fraction, at end of training (default = 1.0 (no decay))\\n\");\n    fprintf(stderr, \"  -c <float>  weight decay (default = 0.0f)\\n\");\n    fprintf(stderr, \"  -sl <float> outlier stability: skip update if loss goes above this in zscore (0.0f=off)\\n\");\n    fprintf(stderr, \"  -sg <float> outlier stability: skip update if grad_norm goes above this in zscore (0.0f=off)\\n\");\n    // evaluation\n    fprintf(stderr, \"  -v <int>    val_loss_every, how often we evaluate val loss (default = 20)\\n\");\n    fprintf(stderr, \"  -m <int>    val_max_steps, up to how many val batches to estimate val loss? (default = 20)\\n\");\n    fprintf(stderr, \"  -s <int>    sample_every, how often we inference the model (default = 20)\\n\");\n    fprintf(stderr, \"  -g <int>    genT, how many steps of inference we do (default = 64)\\n\");\n    fprintf(stderr, \"  -h <int>    hellaswag eval run? (default = 0)\\n\");\n    // debugging\n    fprintf(stderr, \"  -a <int>    overfit a single batch? 0/1. useful for debugging\\n\");\n    // numerics\n    fprintf(stderr, \"  -f <int>    enable_tf32 override (default: 1, set to 0 to disable tf32)\\n\");\n    fprintf(stderr, \"  -w <int>    keep f32 copy of weights for the optimizer? 
(default: 1)\\n\");\n    fprintf(stderr, \"  -ge <int>   gelu fusion: 0=none, 1=forward, 2=forward+backward (default: 2 for >=SM90, 0 for older GPUs)\\n\");\n    // memory management\n    fprintf(stderr, \"  -z <int>    zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\\n\");\n    fprintf(stderr, \"  -r <int>    recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\\n\");\n    // multi-node settings\n    fprintf(stderr, \"  -pn <int>    num_processes (default = 1)\\n\");\n    fprintf(stderr, \"  -pr <int>    process_rank (default = 0)\\n\");\n    fprintf(stderr, \"  -pg <int>    gpus_per_node (default = 8)\\n\");\n    fprintf(stderr, \"  -pi <string> nccl_init_method: tcp,fs,mpi (default = mpi)\\n\");\n    fprintf(stderr, \"  -ps <string> server_ip - used only when nccl_init_method is tcp (default = -1)\\n\");\n    fprintf(stderr, \"  -pf <string> fs_path - used only when nccl_init_method is fs (default = /tmp)\\n\");\n    exit(EXIT_FAILURE);\n}\n\n// ----------------------------------------------------------------------------\n// main training loop\nint main(int argc, char *argv[]) {\n    // read in the (optional) command line arguments\n    const char* train_data_pattern = \"dev/data/tinyshakespeare/tiny_shakespeare_train.bin\";\n    const char* val_data_pattern = \"dev/data/tinyshakespeare/tiny_shakespeare_val.bin\";\n    const char* load_filename = \"gpt2_124M_bf16.bin\"; // bf16 weights of the model\n    const char* lr_scheduler_type = \"cosine\";\n    const char* output_log_dir = NULL;\n    int checkpoint_every = 0; // write checkpoints every how many steps?\n    int checkpoints_keep = 0; // how long checkpoint history do we keep? 
(in units of checkpoints)\n    int major_checkpoint_every = 0; // major checkpoints never get deleted when maintaining history\n    int resume = 0; // resume the optimization, if one is found inside output_log_dir?\n    int B = 4; // batch size\n    int T = 1024; // sequence length max\n    int total_batch_size = -1; // will be calculated down below later, if not provided\n    float learning_rate = 3e-4f;\n    int log_gpu_every = -1;\n    int warmup_iterations = 0;\n    float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training\n    float weight_decay = 0.0f;\n    float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore\n    float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore\n    int val_loss_every = 20; // every how many steps do we eval validation loss?\n    int val_max_steps = 20; // how many batches max do we eval for validation loss?\n    int sample_every = 20; // every how many steps to do inference?\n    int genT = 64; // number of steps of inference we will do\n    int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once\n    int max_steps = -1;\n    int override_enable_tf32 = 1;\n    int use_master_weights = 1;\n    int gelu_fusion = -1; // 0 = none, 1 = forward, 2 = forward+backward (-1 => per-GPU default)\n    int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu\n    int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training\n    int hellaswag_eval = 0;\n    // multi-node settings\n    int num_processes = 1;  // this should be set by the slurm environment\n    int process_rank = 0;  // this should be set by the slurm environment\n    int gpus_per_node = 8;  // this should be set by the slurm environment\n    char nccl_init_method[256] = \"mpi\";  // \"tcp\" or \"fs\" or \"mpi\"\n    char server_ip[256] = \"\";  // used if init_method set to \"tcp\" -> set to your server ip address\n 
   char fs_path[256] = \"\";  // used if init_method set to \"fs\" -> set to a shared filesystem path\n    for (int i = 1; i < argc; i+=2) {\n        if (i + 1 >= argc) { error_usage(); } // must have arg after flag\n        if (argv[i][0] != '-') { error_usage(); } // must start with dash\n        if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { error_usage(); } // must be -x[y] (one dash, one or two letters)\n        // read in the args\n        if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }\n        else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }\n        else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }\n        else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; }\n        else if (argv[i][1] == 'n' && argv[i][2] == '\\0') { checkpoint_every = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size\n        else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'l' && argv[i][2] == '\\0') { learning_rate = atof(argv[i+1]); }\n        else if (argv[i][1] == 'l' && argv[i][2] == 'g') { log_gpu_every = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'u') { warmup_iterations = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'q') { final_learning_rate_frac = atof(argv[i+1]); }\n        else if (argv[i][1] == 'c') { weight_decay = atof(argv[i+1]); }\n        else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); }\n        else if (argv[i][1] == 's' && argv[i][2] == '\\0') { sample_every = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'g' && argv[i][2] == 'e') { gelu_fusion = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'g') { 
genT = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'k') { lr_scheduler_type = argv[i+1]; }\n        else if (argv[i][1] == 'p' && argv[i][2] == 'i') { strcpy(nccl_init_method, argv[i+1]); }\n        else if (argv[i][1] == 'p' && argv[i][2] == 'f') { strcpy(fs_path, argv[i+1]); }\n        else if (argv[i][1] == 'p' && argv[i][2] == 's') { strcpy(server_ip, argv[i+1]); }\n        else if (argv[i][1] == 'p' && argv[i][2] == 'n') { num_processes = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'p' && argv[i][2] == 'r') { process_rank = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'p' && argv[i][2] == 'g') { gpus_per_node = atoi(argv[i+1]); }\n        else if (argv[i][1] == 's' && argv[i][2] == 'l') { skip_update_lossz = atof(argv[i+1]); }\n        else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); }\n        else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); }\n        else { error_usage(); }\n    }\n\n    multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);\n    common_start(override_enable_tf32, false); // common init code for train/test/profile\n\n    // should do a bit more error checking here\n    assert(warmup_iterations >= 0);\n    if (output_log_dir != NULL) {\n        assert(strlen(output_log_dir) < 400); // careful bunch of hardcoded snprintf around 
this\n    }\n    int tokens_per_fwdbwd = B * T * multi_gpu_config.num_processes; // one micro-batch processes this many tokens\n    // calculate sensible default for total batch size as assuming no gradient accumulation\n    if (total_batch_size == -1) { total_batch_size = tokens_per_fwdbwd; }\n    // in the future, we might want to set gelu fusion to 2 for SM90+ and 0 for other GPUs\n    if (gelu_fusion == -1) { gelu_fusion = 0; } // (deviceProp.major >= 9) ? 2 : 0; } // in gpt2_init_common for test_gpt2cu...\n    // calculate the number of gradient accumulation steps from the desired total batch size\n    assert(total_batch_size % tokens_per_fwdbwd == 0);\n    int grad_accum_steps = total_batch_size / tokens_per_fwdbwd;\n    // if we're only overfitting a single batch for debugging, let's overfit the first batch\n    // from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same)\n    if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; }\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n    printf0(\"| Parameter             | Value                                              |\\n\");\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n    printf0(\"| train data pattern    | %-50s |\\n\", train_data_pattern);\n    printf0(\"| val data pattern      | %-50s |\\n\", val_data_pattern);\n    printf0(\"| output log dir        | %-50s |\\n\", output_log_dir == NULL ? 
\"NULL\" : output_log_dir);\n    printf0(\"| checkpoint_every      | %-50d |\\n\", checkpoint_every);\n    printf0(\"| resume                | %-50d |\\n\", resume);\n    printf0(\"| micro batch size B    | %-50d |\\n\", B);\n    printf0(\"| sequence length T     | %-50d |\\n\", T);\n    printf0(\"| total batch size      | %-50d |\\n\", total_batch_size);\n    printf0(\"| LR scheduler          | %-50s |\\n\", lr_scheduler_type);\n    printf0(\"| learning rate (LR)    | %-50e |\\n\", learning_rate);\n    printf0(\"| warmup iterations     | %-50d |\\n\", warmup_iterations);\n    printf0(\"| final LR fraction     | %-50e |\\n\", final_learning_rate_frac);\n    printf0(\"| weight decay          | %-50e |\\n\", weight_decay);\n    printf0(\"| skip update lossz     | %-50f |\\n\", skip_update_lossz);\n    printf0(\"| skip update gradz     | %-50f |\\n\", skip_update_gradz);\n    printf0(\"| max_steps             | %-50d |\\n\", max_steps);\n    printf0(\"| val_loss_every        | %-50d |\\n\", val_loss_every);\n    printf0(\"| val_max_steps         | %-50d |\\n\", val_max_steps);\n    printf0(\"| sample_every          | %-50d |\\n\", sample_every);\n    printf0(\"| genT                  | %-50d |\\n\", genT);\n    printf0(\"| overfit_single_batch  | %-50d |\\n\", overfit_single_batch);\n    printf0(\"| use_master_weights    | %-50s |\\n\", use_master_weights ? \"enabled\" : \"disabled\");\n    printf0(\"| gelu_fusion           | %-50d |\\n\", gelu_fusion);\n    printf0(\"| recompute             | %-50d |\\n\", recompute);\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n    const char* precision_str = (PRECISION_MODE == PRECISION_FP32)\n                              ? (cublas_compute == CUBLAS_COMPUTE_32F_FAST_TF32 ? \"TF32\" : \"FP32\")\n                              : (PRECISION_MODE == PRECISION_FP16 ? 
\"FP16\" : \"BF16\");\n    printf0(\"| device                | %-50s |\\n\", deviceProp.name);\n    printf0(\"| peak TFlops           | %-50.1f |\\n\", get_flops_promised(deviceProp.name, PRECISION_MODE));\n    printf0(\"| precision             | %-50s |\\n\", precision_str);\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // figure out if we are going to be resuming the optimization\n    int resuming = 0;\n    // find the DONE file with the highest step count\n    int resume_max_step = find_max_step(output_log_dir);\n    if (resume == 1) { // is -y 1 resume flag set?\n        assert(output_log_dir != NULL);\n        if (resume_max_step != -1) {\n            resuming = 1; // -y 1 is set, and we found a checkpoint we can resume from\n            snprintf(filename_buffer, sizeof(filename_buffer), \"%s/model_%08d.bin\", output_log_dir, resume_max_step);\n        }\n    }\n\n    // build the GPT-2 model\n    GPT2 model;\n    gpt2_init_common(&model);\n    if (resuming == 1) {\n        // if `-y 1` was set, then we are resuming from the latest checkpoint\n        // if we are using master weights, we'll init them later inside load_state()\n        bool weight_init = !use_master_weights;\n        gpt2_build_from_checkpoint(&model, filename_buffer, weight_init);\n    } else if (ends_with_bin(load_filename)) {\n        // otherwise, if this is a .bin file, we assume it's a model, let's init from it\n        gpt2_build_from_checkpoint(&model, load_filename);\n    } else {\n        // if it's not .bin, it could be a \"special descriptor\". This descriptor is used to\n        // construct GPT-2 / GPT-3 models in a convenient format. See the function for docs.\n        gpt_build_from_descriptor(&model, load_filename);\n    }\n\n    model.use_master_weights = use_master_weights;\n    model.gelu_fusion = gelu_fusion;\n    model.recompute = recompute;\n    printf0(\"| weight init method    | %-50s |\\n\", resuming == 1 ? 
\"intermediate checkpoint\" : load_filename);\n    printf0(\"| max_sequence_length T | %-50d |\\n\", model.config.max_seq_len);\n    printf0(\"| vocab_size V          | %-50d |\\n\", model.config.vocab_size);\n    printf0(\"| padded_vocab_size Vp  | %-50d |\\n\", model.config.padded_vocab_size);\n    printf0(\"| num_layers L          | %-50d |\\n\", model.config.num_layers);\n    printf0(\"| num_heads NH          | %-50d |\\n\", model.config.num_heads);\n    printf0(\"| channels C            | %-50d |\\n\", model.config.channels);\n    printf0(\"| num_parameters        | %-50zu |\\n\", model.num_parameters);\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // build DataLoaders for both train and val\n    int permute_train_loader = (overfit_single_batch == 1) ? 0 : 1;\n    DataLoader train_loader, val_loader;\n    dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, permute_train_loader);\n    dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 0);\n    // figure out the number of training steps we will run for\n    int train_num_batches = max_steps; // passed in from command line\n    if (train_num_batches == -1) {\n        // sensible default is to train for exactly one epoch\n        size_t ntok = train_loader.num_tokens;\n        // the number of (outer loop) steps each process should take for us to reach one epoch\n        train_num_batches = ntok / total_batch_size;\n    }\n    // figure out the number of validation steps to run for\n    int val_num_batches = val_max_steps; // passed in from command line\n    if (val_num_batches == -1) {\n        // sensible default is to evaluate the full validation split\n        size_t ntok = val_loader.num_tokens;\n        // note that unlike the training loop, there is no gradient accumulation inner loop here\n        val_num_batches = 
ntok / tokens_per_fwdbwd;\n    }\n    printf0(\"| train_num_batches     | %-50d |\\n\", train_num_batches);\n    printf0(\"| val_num_batches       | %-50d |\\n\", val_num_batches);\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // build an EvalLoader for HellaSwag\n    EvalLoader eval_loader;\n    const char* hellaswag_path = \"dev/data/hellaswag/hellaswag_val.bin\";\n    const bool hellaswag_available = access(hellaswag_path, F_OK) == 0;\n    const bool run_hellaswag = hellaswag_eval && hellaswag_available;\n    if (run_hellaswag) {\n        evalloader_init(&eval_loader, hellaswag_path, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);\n    }\n    printf0(\"| run hellaswag         | %-50s |\\n\", run_hellaswag ? \"yes\" : \"no\");\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // pretty print in a table the multi-gpu configuration as well\n    set_zero_configs(&multi_gpu_config, zero_stage, model.num_parameters);\n    printf0(\"| num_processes         | %-50d |\\n\", multi_gpu_config.num_processes);\n    printf0(\"| zero_stage            | %-50d |\\n\", multi_gpu_config.zero_stage);\n    printf0(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // prints outside of pretty table to here and below\n    if (!hellaswag_available) {\n        printf0(\"HellaSwag eval not found at %s, skipping its evaluation\\n\", hellaswag_path);\n        printf0(\"You can run `python dev/data/hellaswag.py` to export and use it with `-h 1`.\\n\");\n    }\n    // more prints related to allocations from gpt2_build_from_checkpoint down here to not mess up our table above\n    printf0(\"num_parameters: %zu => bytes: %zu\\n\", model.num_parameters, model.num_parameters_bytes);\n    printf0(\"allocated %d MiB for model parameters\\n\", (int)round(model.num_parameters_bytes / (1024 * 1024)));\n    // few more 
prints for gradient accumulation math up above\n    printf0(\"batch_size B=%d * seq_len T=%d * num_processes=%d and total_batch_size=%d\\n\",\n            B, T, multi_gpu_config.num_processes, total_batch_size);\n    printf0(\"=> setting grad_accum_steps=%d\\n\", grad_accum_steps);\n\n    // set up logging\n    if (multi_gpu_config.process_rank == 0) { create_dir_if_not_exists(output_log_dir); }\n    Logger logger;\n    logger_init(&logger, output_log_dir, multi_gpu_config.process_rank, resume);\n\n    // set up the Tokenizer\n    Tokenizer tokenizer;\n    tokenizer_init(&tokenizer, \"gpt2_tokenizer.bin\");\n\n    // set up learning rate scheduler\n    LearningRateScheduler lr_scheduler;\n    lr_scheduler_init(&lr_scheduler, lr_scheduler_type, learning_rate,\n                      warmup_iterations, train_num_batches, final_learning_rate_frac);\n\n    // some memory for generating samples from the model\n    int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));\n    floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX));\n    float*  cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float));\n\n    // if we found a checkpoint to resume from, load the optimization state\n    int step = 0;\n    gpt2_allocate_state(&model, B, T);\n    if (resuming == 1) {\n        snprintf(filename_buffer, sizeof(filename_buffer), \"%s/state_%08d_%05d.bin\", output_log_dir, resume_max_step, multi_gpu_config.process_rank);\n        load_state(&step, &model, &train_loader, filename_buffer);\n    }\n\n    // init an OutlierDetector the training loss\n    OutlierDetector loss_outlier_detector, grad_norm_outlier_detector;\n    init_detector(&loss_outlier_detector);\n    init_detector(&grad_norm_outlier_detector);\n\n    // do some checks here before we kick off training\n    // cross-check the desired sequence length T with the model's max sequence length\n    if (T < model.config.max_seq_len) {\n        printf0(\"!!!!!!!!\\n\");\n       
 printf0(\"WARNING:\\n\");\n        printf0(\"- The training sequence length is: T=%d (set with -t)\\n\", T);\n        printf0(\"- The model's max sequence length is: max_seq_len=%d\\n\", model.config.max_seq_len);\n        printf0(\"You are attempting to train with a sequence length shorter than the model's max.\\n\");\n        printf0(\"This will lead to unused parameters in the wpe position embedding weights.\\n\");\n        printf0(\"If you know what you're doing you can ignore this warning.\\n\");\n        printf0(\"If you're like ???, you are most likely misconfiguring your training run.\\n\");\n        printf0(\"---> HINT: If you're training GPT-2 use -t 1024. If GPT-3, use -t 2048.\\n\");\n        printf0(\"!!!!!!!!\\n\");\n    }\n    // in any case, this must be true or we'd index beyond the model's wpe (position embedding table)\n    assert(T <= model.config.max_seq_len);\n\n    // train\n    cudaEvent_t start, end;\n    cudaCheck(cudaEventCreate(&start));\n    cudaCheck(cudaEventCreate(&end));\n    cudaCheck(cudaProfilerStart());\n    double total_sum_iteration_time_s = 0.0;\n    float ema_tokens_per_second = 0.0f;\n    for (; step <= train_num_batches; step++) {\n        NvtxRange step_range(\"Train step\", step);\n\n        int last_step = step == train_num_batches;\n\n        // once in a while estimate the validation loss (all processes collaborate)\n        if (step % val_loss_every == 0 || last_step) {\n            NvtxRange validation_range(\"validation\");\n            float val_loss = 0.0f;\n            dataloader_reset(&val_loader);\n            for (int i = 0; i < val_num_batches; i++) {\n                dataloader_next_batch(&val_loader);\n                val_loss += gpt2_validate(&model, val_loader.inputs, val_loader.targets, B, T);\n            }\n            val_loss /= val_num_batches;\n            val_loss = multi_gpu_cpu_float_sum(val_loss, &multi_gpu_config) / multi_gpu_config.num_processes;\n            printf0(\"val loss %f\\n\", 
val_loss);\n            logger_log_val(&logger, step, val_loss);\n        }\n\n        // once in a while estimate HellaSwag accuracy (all processes collaborate)\n        if (run_hellaswag &&\n           ((step > 0 && step % val_loss_every == 0) || last_step)) {\n            NvtxRange evaluation_range(\"evaluation\");\n            float eval_acc_norm = 0.0f;\n            evalloader_reset(&eval_loader);\n            for (int i = 0; i < eval_loader.num_batches; i++) {\n                if (i % 10 == 0) { printf(\"evaluating HellaSwag: %d/%d\\r\", i, eval_loader.num_batches); }\n                evalloader_next_batch(&eval_loader);\n                gpt2_validate(&model, eval_loader.inputs, eval_loader.targets, B, T);\n                int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses);\n                eval_acc_norm += (float)correct;\n            }\n            // careful because not all ranks may have the exact same allocation of number of examples\n            eval_acc_norm = multi_gpu_cpu_float_sum(eval_acc_norm, &multi_gpu_config);\n            printf0(\"HellaSwag: %d/%d = %f\\n\", (int)eval_acc_norm, eval_loader.num_examples, eval_acc_norm / eval_loader.num_examples);\n            logger_log_eval(&logger, step, eval_acc_norm / eval_loader.num_examples);\n        }\n\n        // once in a while do model inference to print generated text (only rank 0)\n        if (multi_gpu_config.process_rank == 0 && sample_every > 0 &&\n           (step > 0 && (step % sample_every) == 0 || last_step)) {\n            NvtxRange generation_range(\"generation\");\n            unsigned long long sample_rng_state = 1337;\n            // fill up gen_tokens with the <|endoftext|> token, which kicks off the generation\n            int eot_token = tokenizer.eot_token;\n            for(int i = 0; i < B * T; ++i) {\n                gen_tokens[i] = eot_token;\n            }\n            // now sample from the model autoregressively\n            
printf(\"generating:\\n---\\n\");\n            for (int t = 1; t < genT; t++) {\n                NvtxRange generation_range(\"Generation step\", t);\n                // we try not to be too wasteful for inference by not calculating all of B,T\n                // Using a smaller B is always bit-for-bit identical, but T is more tricky\n                // for non-CUDNN, we need to make sure the attention buffer is memset to 0\n                // for cuDNN, it might suddenly decide to use a slightly different algorithm...\n                // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical\n                // (but even if it wasn't fully identical that's probably not the end of the world)\n                // note this is still somewhat wasteful because we don't have a KV cache!\n                gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256));\n                // get the V-dimensional vector probs[0, t-1, :]\n                floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;\n                // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)\n                cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost));\n                // convert to FP32 into cpu_logits (this does nothing useful if floatX == float)\n                for (int i = 0; i < model.config.vocab_size; i++) {\n                    cpu_logits[i] = (float)cpu_logits_raw[i];\n                }\n                // sample the next token\n                float coin = random_f32(&sample_rng_state);\n                int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin);\n                gen_tokens[t] = next_token;\n                // print the generated token, either using the Tokenizer or a fallback\n                if (tokenizer.init_ok) {\n                    const char* token_str = 
tokenizer_decode(&tokenizer, next_token);\n                    safe_printf(token_str);\n                } else {\n                    // fall back to printing the token id\n                    printf(\"%d \", next_token);\n                }\n                fflush(stdout);\n            }\n            printf(\"\\n---\\n\");\n        }\n\n        // once in a while checkpoint the optimization state (all ranks)\n        if ((checkpoint_every > 0 && output_log_dir != NULL && resuming == 0) &&\n            ((step > 0 && step % checkpoint_every == 0) || last_step)) {\n            // writes model .bin file, state .bin files, and DONE file for step\n            write_checkpoint(output_log_dir, step, &model, &train_loader, &multi_gpu_config);\n            // we only keep checkpoints_keep checkpoints on disk to save space\n            // so now that we wrote a new checkpoint, delete one old one (unless it is a \"major\" checkpoint)\n            // we only do this if checkpoint keeping is turned on (checkpoints_keep > 0)\n            int step_delete = step - checkpoints_keep * checkpoint_every;\n            if (checkpoints_keep > 0 && step_delete > 0 &&\n               (major_checkpoint_every == 0 || step_delete % major_checkpoint_every != 0)\n                ) {\n                delete_checkpoint(output_log_dir, step_delete, &multi_gpu_config);\n            }\n        }\n        resuming = 0;\n\n        // bit confusing: we want to make sure to eval and sample on 0th iteration\n        // but also after the very last iteration. 
so we loop for step <= train_num_batches\n        // instead of just < train_num_batches (one extra due to <=), only to do\n        // the validation/sampling one last time, and then we break right here as we're done.\n        if (last_step) { break; }\n\n        // --------------- TRAINING SECTION BEGIN -----------------\n        if (overfit_single_batch == 1) {\n            // if we are trying to overfit a single batch, we reset the loader here\n            dataloader_reset(&train_loader);\n        }\n        // do one training step, doing forward/backward/update on total_batch_size tokens\n        cudaCheck(cudaEventRecord(start));\n        // gradient and loss accumulation loop over micro-batches\n        for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) {\n            // fetch the next data batch\n            dataloader_next_batch(&train_loader);\n            // forward pass. note that grad_accum_steps (which scales down the loss) is passed to the backward call below, not here\n            gpt2_forward(&model, train_loader.inputs, B, T);\n            // backward pass. 
all model params accumulate gradients with += inside this inner loop\n            gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step);\n        }\n        float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss)); // loss z-score\n        // fetch the next learning rate\n        float step_learning_rate = get_learning_rate(&lr_scheduler, step);\n        // calculate the gradient norm and how much we wish to scale the gradient\n        float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);\n        float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score\n        // update the model parameters\n        if (isfinite(zloss) && skip_update_lossz != 0.0f && zloss > skip_update_lossz) {\n            printf0(\"skipping update due to loss z-score of %f\\n\", zloss);\n        } else if (isfinite(zgrad) && skip_update_gradz != 0.0f && zgrad > skip_update_gradz) {\n            printf0(\"skipping update due to grad z-score of %f\\n\", zgrad);\n        } else {\n            // clip the gradient norm to a maximum value\n            float grad_clip = 1.0f;\n            float grad_scale = (grad_norm > grad_clip) ? 
grad_clip / grad_norm : 1.0f;\n            gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config);\n        }\n        cudaCheck(cudaEventRecord(end));\n        cudaCheck(cudaEventSynchronize(end)); // wait for the end event to finish to get correct timings\n        // --------------- TRAINING SECTION END -------------------\n        // everything that follows now is just diagnostics, prints, logging, etc.\n\n        // todo - move or double-buffer all of this timing logic to avoid idling the GPU at this point!\n        float time_elapsed_ms;\n        cudaCheck(cudaEventElapsedTime(&time_elapsed_ms, start, end));\n        size_t tokens_processed = (size_t)multi_gpu_config.num_processes * B * T * grad_accum_steps;\n        float tokens_per_second = tokens_processed / time_elapsed_ms * 1000.0f;\n        float bias_corrected_ema_tokens_per_second = tokens_per_second; // by default set to non-ema version\n        if (step > 0) { // consider the first batch to be a warmup (e.g. 
cuBLAS/cuDNN initialisation)\n            total_sum_iteration_time_s += time_elapsed_ms / 1000.0f;\n            // smooth out the tok/s with an exponential moving average, and bias correct just like in AdamW\n            ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second;\n            bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step));\n        }\n        float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f);\n        printf0(\"step %4d/%d | loss %7.6f (%+.2fz)| norm %6.4f (%+.2fz)| lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\\n\",\n                step + 1, train_num_batches, model.mean_loss, zloss, grad_norm, zgrad, step_learning_rate,\n                time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second);\n        if(log_gpu_every > 0 && (step + 1) % log_gpu_every == 0) {\n            GPUUtilInfo gpu_info = get_gpu_utilization_info();\n            printf0(\"                  compute %2.1f%% | memory: %2.1f%% | fan: %2d%% | %4d MHz / %4d MHz | %3d W / %3d W | %d°C / %d°C | %s\\n\",\n                    gpu_info.gpu_utilization, gpu_info.mem_utilization, gpu_info.fan, gpu_info.clock, gpu_info.max_clock, gpu_info.power / 1000, gpu_info.power_limit / 1000,\n                    gpu_info.temperature, gpu_info.temp_slowdown, gpu_info.throttle_reason);\n        }\n        logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm);\n\n        // disable the profiler after 3 steps of optimization\n        if (step == 3) { cudaProfilerStop(); }\n    }\n    // add a total average, for optimizations that are only mild improvements (excluding 1st batch as warmup)\n    printf0(\"total average iteration time: %f ms\\n\", total_sum_iteration_time_s / (train_num_batches-1) * 1000);\n\n    // free and destroy everything\n    cudaCheck(cudaEventDestroy(end));\n    cudaCheck(cudaEventDestroy(start));\n    if (run_hellaswag) { 
evalloader_free(&eval_loader); }\n    dataloader_free(&train_loader);\n    dataloader_free(&val_loader);\n    tokenizer_free(&tokenizer);\n    free(cpu_logits_raw);\n    free(cpu_logits);\n    free(gen_tokens);\n    multi_gpu_config_free(&multi_gpu_config);\n    gpt2_free(&model);\n    common_free(model);\n    return 0;\n}\n#endif\n"
  },
  {
    "path": "train_gpt2.py",
    "content": "\"\"\"\nReference code for GPT-2 training and inference.\nWill save the model weights into files, to be read from C as initialization.\n\nReferences:\n1) the official GPT-2 TensorFlow implementation released by OpenAI:\nhttps://github.com/openai/gpt-2/blob/master/src/model.py\n2) huggingface/transformers PyTorch implementation:\nhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py\n\nExample launches to only benchmark the speed of bfloat16 compiled GPU training:\n1 GPU:\npython train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16\nyou can also turn on flash-attention by appending --flash=1\n4 GPU:\ntorchrun --standalone --nproc_per_node=4 train_gpt2.py --write_tensors=0 --num_iterations=50 --sequence_length=1024 --compile=1 --tensorcores=1 --dtype=bfloat16\n\"\"\"\n\nimport os\nimport math\nimport glob\nimport struct\nimport inspect\nfrom contextlib import nullcontext\nfrom dataclasses import dataclass\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport torch._inductor.config as config\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.distributed import init_process_group, destroy_process_group\nfrom torch.distributed.optim import ZeroRedundancyOptimizer\nimport torch.distributed as dist\n\n# -----------------------------------------------------------------------------\n# PyTorch nn.Module definitions for the GPT-2 model\n\nclass NewGELU(nn.Module):\n    \"\"\"Careful there are a few versions of GeLU, this one is the exact one used by OpenAI\"\"\"\n    def forward(self, input):\n        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))\n\n# using a global to toggle flash-attention\nFLASH = 0\n\nclass CausalSelfAttention(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        
assert config.n_embd % config.n_head == 0\n        # key, query, value projections for all heads, but in a batch\n        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)\n        # output projection\n        self.c_proj = nn.Linear(config.n_embd, config.n_embd)\n        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1\n        # regularization\n        self.n_head = config.n_head\n        self.n_embd = config.n_embd\n        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though\n        self.register_buffer(\"bias\", torch.tril(torch.ones(config.block_size, config.block_size))\n                                     .view(1, 1, config.block_size, config.block_size))\n\n    def forward(self, x):\n        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n        qkv = self.c_attn(x)\n        q, k, v = qkv.split(self.n_embd, dim=2)\n        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)\n        if FLASH:\n            # flashattention\n            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)\n        else:\n            # manual implementation of attention\n            # this materializes the large (T,T) matrix for all the queries and keys\n            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))\n            att = F.softmax(att, dim=-1)\n            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n        # output projection\n        y = self.c_proj(y)\n        return 
y\n\nclass MLP(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)\n        self.gelu    = NewGELU()\n        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)\n        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1\n\n    def forward(self, x):\n        x = self.c_fc(x)\n        x = self.gelu(x)\n        x = self.c_proj(x)\n        return x\n\nclass Block(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.ln_1 = nn.LayerNorm(config.n_embd)\n        self.attn = CausalSelfAttention(config)\n        self.ln_2 = nn.LayerNorm(config.n_embd)\n        self.mlp = MLP(config)\n\n    def forward(self, x):\n        x = x + self.attn(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n# -----------------------------------------------------------------------------\n# The main GPT-2 model\n\n@dataclass\nclass GPTConfig:\n    block_size: int = 1024\n    vocab_size: int = 50257\n    n_layer: int = 12\n    n_head: int = 12\n    n_embd: int = 768\n\nclass GPT(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n        self.transformer = nn.ModuleDict(dict(\n            wte = nn.Embedding(config.vocab_size, config.n_embd),\n            wpe = nn.Embedding(config.block_size, config.n_embd),\n            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n            ln_f = nn.LayerNorm(config.n_embd),\n        ))\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n        self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights\n        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying\n\n        # init all weights, use a torch rng object to be very careful\n        self.init_rng = torch.Generator()\n        self.init_rng.manual_seed(42)\n        
self.apply(self._init_weights)\n\n    def _init_weights(self, module):\n        if isinstance(module, nn.Linear):\n            # apply special scaled init to the residual projections, per GPT-2 paper\n            std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)\n            # we want to skip initializing lm_head, which shares parameters with wte\n            # and wte was already initialized down below during the Embedding init\n            if not hasattr(module, 'LLMC_SKIP_INIT'):\n                torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)\n\n    def forward(self, idx, targets=None, return_logits=True):\n        device = idx.device\n        b, t = idx.size()\n        assert t <= self.config.block_size, f\"Cannot forward sequence of length {t}, block size is only {self.config.block_size}\"\n        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)\n\n        # forward the GPT model itself\n        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)\n        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)\n        x = tok_emb + pos_emb\n\n        for block in self.transformer.h:\n            x = block(x)\n        x = self.transformer.ln_f(x)\n\n        if targets is not None:\n            # if we are given some desired targets also calculate the loss\n            logits = self.lm_head(x)\n            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)\n        else:\n            # inference-time mini-optimization: only forward the lm_head on the very last position\n            logits = self.lm_head(x[:, [-1], :]) # note: 
using list [-1] to preserve the time dim\n            loss = None\n\n        # there are performance reasons why not returning logits is prudent, if not needed\n        if not return_logits:\n            logits = None\n\n        return logits, loss\n\n    @classmethod\n    def from_pretrained(cls, model_type):\n        \"\"\"Loads pretrained GPT-2 model weights from huggingface\"\"\"\n        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}\n        from transformers import GPT2LMHeadModel\n        print(\"loading weights from pretrained gpt: %s\" % model_type)\n\n        # n_layer, n_head and n_embd are determined from model_type\n        config_args = {\n            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params\n            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n        }[model_type]\n        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints\n        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints\n        # create a from-scratch initialized minGPT model\n        config = GPTConfig(**config_args)\n        model = GPT(config)\n        sd = model.state_dict()\n        sd_keys = sd.keys()\n        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param\n\n        # init a huggingface/transformers model\n        model_hf = GPT2LMHeadModel.from_pretrained(model_type)\n        sd_hf = model_hf.state_dict()\n\n        # copy while ensuring all of the parameters are aligned and match in names and shapes\n        sd_keys_hf = sd_hf.keys()\n        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer\n        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] 
# same, just the mask (buffer)\n        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']\n        # basically the openai checkpoints use a \"Conv1D\" module, but we only want to use a vanilla Linear\n        # this means that we have to transpose these weights when we import them\n        assert len(sd_keys_hf) == len(sd_keys), f\"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}\"\n        for k in sd_keys_hf:\n            if any(k.endswith(w) for w in transposed):\n                # special treatment for the Conv1D weights we need to transpose\n                assert sd_hf[k].shape[::-1] == sd[k].shape\n                with torch.no_grad():\n                    sd[k].copy_(sd_hf[k].t())\n            else:\n                # vanilla copy over the other parameters\n                assert sd_hf[k].shape == sd[k].shape\n                with torch.no_grad():\n                    sd[k].copy_(sd_hf[k])\n\n        return model\n\n    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):\n        # start with all of the candidate parameters\n        param_dict = {pn: p for pn, p in self.named_parameters()}\n        # filter out those that do not require grad\n        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\n        # i.e. 
all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n        optim_groups = [\n            {'params': decay_params, 'weight_decay': weight_decay},\n            {'params': nodecay_params, 'weight_decay': 0.0}\n        ]\n        num_decay_params = sum(p.numel() for p in decay_params)\n        num_nodecay_params = sum(p.numel() for p in nodecay_params)\n        print0(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n        print0(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n        # Create AdamW optimizer and use the fused version if it is available\n        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n        use_fused = fused_available and device_type == 'cuda'\n        print0(f\"using fused AdamW: {use_fused}\")\n        if zero_stage == 1:\n            print0(\"using ZeroRedundancyOptimizer\")\n            optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,\n                                                lr=learning_rate, betas=betas, fused=use_fused)\n            optimizer.add_param_group(optim_groups[1])\n        else:\n            print0(\"using regular AdamW\")\n            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)\n        return optimizer\n\n    @torch.no_grad()\n    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n        \"\"\"\n        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\n        the sequence max_new_tokens times, feeding the predictions back into the model each time.\n        Most likely you'll want to make sure to be in model.eval() mode of operation for this.\n        \"\"\"\n  
      for _ in range(max_new_tokens):\n            # if the sequence context is growing too long we must crop it at block_size\n            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]\n            # forward the model to get the logits for the index in the sequence\n            logits, _ = self(idx_cond)\n            # pluck the logits at the final step and scale by desired temperature\n            logits = logits[:, -1, :] / temperature\n            # optionally crop the logits to only the top k options\n            if top_k is not None:\n                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n                logits[logits < v[:, [-1]]] = -float('Inf')\n            # apply softmax to convert logits to (normalized) probabilities\n            probs = F.softmax(logits, dim=-1)\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1)\n            # append sampled index to the running sequence and continue\n            idx = torch.cat((idx, idx_next), dim=1)\n\n        return idx\n\n# -----------------------------------------------------------------------------\n# Our own simple Distributed Data Loader\n\ndef _peek_data_shard(filename):\n    # only reads the header, returns header data\n    with open(filename, \"rb\") as f:\n        # first read the header, which is 256 int32 integers (4 bytes each)\n        header = np.frombuffer(f.read(256*4), dtype=np.int32)\n    if header[0] != 20240520:\n        print(\"ERROR: magic number mismatch in the data .bin file!\")\n        print(\"---> HINT: Are you passing in a correct file with --input_bin?\")\n        print(\"---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README\")\n        print(\"---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try\")\n        exit(1)\n    assert header[1] == 1, \"unsupported version\"\n    ntok = header[2] # number of 
tokens (claimed)\n    return ntok # for now just return the number of tokens\n\ndef _load_data_shard(filename):\n    with open(filename, \"rb\") as f:\n        # first read the header, which is 256 int32 integers (4 bytes each)\n        header = np.frombuffer(f.read(256*4), dtype=np.int32)\n        assert header[0] == 20240520, \"magic number mismatch in the data .bin file\"\n        assert header[1] == 1, \"unsupported version\"\n        ntok = header[2] # number of tokens (claimed)\n        # the rest of it are tokens, stored as uint16\n        tokens = np.frombuffer(f.read(), dtype=np.uint16)\n    assert len(tokens) == ntok, \"number of tokens read does not match header?\"\n    return tokens\n\nclass DistributedDataLoader:\n    def __init__(self, filename_pattern, B, T, process_rank, num_processes):\n        self.process_rank = process_rank\n        self.num_processes = num_processes\n        self.B = B\n        self.T = T\n\n        # glob files that match the pattern\n        self.files = sorted(glob.glob(filename_pattern))\n        assert len(self.files) > 0, f\"did not find any files that match the pattern {filename_pattern}\"\n\n        # load and validate all data shards, count number of tokens in total\n        ntok_total = 0\n        for fname in self.files:\n            shard_ntok = _peek_data_shard(fname)\n            assert shard_ntok >= num_processes * B * T + 1\n            ntok_total += shard_ntok\n        self.ntok_total = ntok_total\n        print0(f\"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files\")\n\n        # kick things off\n        self.current_shard = None\n        self.reset()\n\n    def reset(self):\n        # we're being a bit clever here: if we already had shard 0 loaded,\n        # then don't do the work to reload it, just reset the pointer\n        if self.current_shard != 0:\n            self.current_shard = 0\n            self.tokens = _load_data_shard(self.files[self.current_shard])\n        
self.current_position = self.process_rank * self.B * self.T\n\n    def advance(self): # advance to next data shard\n        self.current_shard = (self.current_shard + 1) % len(self.files)\n        self.current_position = self.process_rank * self.B * self.T\n        self.tokens = _load_data_shard(self.files[self.current_shard])\n\n    def next_batch(self):\n        B = self.B\n        T = self.T\n        buf = self.tokens[self.current_position : self.current_position+B*T+1]\n        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)\n        x = (buf[:-1]).view(B, T) # inputs\n        y = (buf[1:]).view(B, T) # targets\n        # advance the start pointer in current shard\n        self.current_position += B * T * self.num_processes\n        # if loading the next batch would be out of bounds advance the shard\n        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):\n            self.advance()\n        return x, y\n\n# -----------------------------------------------------------------------------\n# Python -> C bridge utilities for saving params/grads/activations to .bin files\n\ndef write_fp32(tensor, file):\n    t = tensor.detach().cpu().to(torch.float32)\n    b = t.numpy().tobytes()\n    file.write(b)\n\ndef write_bf16(tensor, file):\n    t = tensor.detach().cpu().to(torch.bfloat16)\n    # numpy doesn't have bf16 datatype so we have to trick it\n    t = t.view(torch.int16) # trick: reinterpret as int16\n    b = t.numpy().tobytes()\n    file.write(b)\n\ndef write_tensors(model_tensors, L, file, dtype):\n    # writes the GPT-2 model's weights to a binary file\n    assert dtype in {\"float32\", \"bfloat16\"}\n    write_fun = write_fp32 if dtype == \"float32\" else write_bf16\n    write_fun(model_tensors[\"transformer.wte.weight\"], file) # (V, C)\n    write_fun(model_tensors[\"transformer.wpe.weight\"], file) # (T, C)\n    for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.ln_1.weight\"], file)\n    
for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.ln_1.bias\"], file)\n    for i in range(L): # (L, 3C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.attn.c_attn.weight\"], file)\n    for i in range(L): # (L, 3C)\n        write_fun(model_tensors[f\"transformer.h.{i}.attn.c_attn.bias\"], file)\n    for i in range(L): # (L, C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.attn.c_proj.weight\"], file)\n    for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.attn.c_proj.bias\"], file)\n    for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.ln_2.weight\"], file)\n    for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.ln_2.bias\"], file)\n    for i in range(L): # (L, 4C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_fc.weight\"], file)\n    for i in range(L): # (L, 4C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_fc.bias\"], file)\n    for i in range(L): # (L, C, 4C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_proj.weight\"], file)\n    for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_proj.bias\"], file)\n    write_fun(model_tensors[\"transformer.ln_f.weight\"], file) # (C, )\n    write_fun(model_tensors[\"transformer.ln_f.bias\"], file) # (C, )\n\n@torch.no_grad()\ndef pad_vocab(tensor, multiple=128, value=0):\n    \"\"\"\n    The dimension of the vocab size in GPT-2 is 50,257\n    which is unfortunately a very unfriendly number for a lot of\n    matrix operations on the GPU. So we pad it to the nearest\n    friendlier multiple, e.g. 50,304 if multiple=128 when we\n    export the weights into C land. 
This is a NOOP algorithmically\n    and is only done to make the tensor operations more efficient.\n    \"\"\"\n    assert tensor.ndim == 2\n    V, C = tensor.shape\n    assert V == 50257, \"just being defensive here\"\n    # calculate padded vocab size by rounding up to nearest multiple\n    Vp = ((V + multiple - 1) // multiple) * multiple\n    # pad the tensor\n    pad_rows = Vp - V\n    padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value)\n    assert padded.shape == (Vp, C)\n    return padded\n\ndef write_model(model, filename, dtype):\n    # everything we need to instantiate the model\n    # 1) header is: version int, GPTConfig ints, padding to 1024 bytes\n    assert dtype in {\"float32\", \"bfloat16\"} # float16 todo maybe later\n    version = {\n        \"float32\": 3, # 3: all tensors are fp32, padded vocab\n        \"bfloat16\": 5, # 5: all tensors are bf16, padded vocab\n    }[dtype]\n    header = torch.zeros(256, dtype=torch.int32)\n    header[0] = 20240326 # magic\n    header[1] = version # checkpoint version\n    header[2] = model.config.block_size\n    header[3] = model.config.vocab_size\n    header[4] = model.config.n_layer\n    header[5] = model.config.n_head\n    header[6] = model.config.n_embd\n    # 2) the parameters follow the header\n    params = {name: param.cpu() for name, param in model.named_parameters()}\n    # pad the vocab to a multiple of 128 here at export, for efficiency in C\n    wte = params[\"transformer.wte.weight\"] # (V, C)\n    wte_padded = pad_vocab(wte) # (Vp, C)\n    params[\"transformer.wte.weight\"] = wte_padded # (Vp, C)\n    print(f\"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}\")\n    header[7] = wte_padded.size(0) # padded vocab size store in header\n    # now write to file\n    with open(filename, \"wb\") as file:\n        file.write(header.numpy().tobytes()) # header\n        write_tensors(params, model.config.n_layer, file, dtype) # params\n    print(f\"wrote 
{filename}\")\n\ndef write_state(model, x, y, logits, loss, filename):\n    # the state is used for debugging.\n    # it contains information about the input, logits, loss, and the parameter gradients\n    # this can be used for checking the computation correctness in C\n    header = torch.zeros(256, dtype=torch.int32)\n    header[0] = 20240327 # magic\n    header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes)\n    header[2] = x.size(0) # batch size of the batch, B\n    header[3] = x.size(1) # temporal extent of the batch, T\n    grads = {name: param.grad.cpu() for name, param in model.named_parameters()}\n    # pad the vocab grads here as well, to mirror write_model\n    wte_grad = grads[\"transformer.wte.weight\"] # (V, C)\n    wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan?\n    grads[\"transformer.wte.weight\"] = wte_grad_padded # (Vp, C)\n    print(f\"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}\")\n    with open(filename, \"wb\") as file:\n        # header\n        file.write(header.numpy().tobytes())\n        # input x\n        file.write(x.cpu().numpy().astype(\"int32\").tobytes()) # (B, T)\n        # targets y\n        file.write(y.cpu().numpy().astype(\"int32\").tobytes()) # (B, T)\n        # logits (result of the model forward pass)\n        write_fp32(logits.cpu(), file)\n        # loss (single float, result of the cross entropy loss)\n        write_fp32(loss.cpu(), file)\n        # gradients\n        write_tensors(grads, model.config.n_layer, file, \"float32\")\n    print(f\"wrote {filename}\")\n\ndef write_tokenizer(enc, filename):\n    n = enc.max_token_value + 1\n    header = torch.zeros(256, dtype=torch.int32)\n    header[0] = 20240328 # magic\n    header[1] = 2 # tokenizer version = 2 (1 -> 2: includes EOT token)\n    header[2] = n # number of tokens\n    header[3] = enc.eot_token # EOT token\n    with open(filename, \"wb\") as file:\n        
file.write(header.numpy().tobytes())\n        for i in range(n):\n            b = enc.decode_bytes([i])\n            length = len(b)\n            assert length < 256, f\"Token length exceeds 255: {length}\"\n            file.write(struct.pack(\"<B\", length))  # Write the length as a 1-byte unsigned integer\n            file.write(b)  # Write the actual bytes\n    print(f\"wrote {filename}\")\n\n# -----------------------------------------------------------------------------\n# int main\n\ndef print0(*args, **kwargs):\n    # modified print that only prints from the master process\n    # if this is not a distributed run, it's just a print\n    if int(os.environ.get(\"RANK\", 0)) == 0:\n        print(*args, **kwargs)\n\nif __name__ == \"__main__\":\n    import time\n    import argparse\n    import tiktoken\n    print0(f\"Running pytorch {torch.version.__version__}\")\n\n    # default settings will overfit a tiny batch of data\n    # and save model weights and debug state to disk on the first iteration\n    parser = argparse.ArgumentParser()\n    # file system input / output\n    parser.add_argument(\"--input_bin\", type=str, default=\"dev/data/tinyshakespeare/tiny_shakespeare_val.bin\", help=\"input .bin to train on\")\n    parser.add_argument(\"--input_val_bin\", type=str, default=\"\", help=\"input .bin to eval validation loss on\")\n    parser.add_argument(\"--output_dir\", type=str, default=\"\", help=\"output directory to which to write logs and checkpoints\")\n    parser.add_argument(\"--model\", type=str, default=\"gpt2\", help=\"gpt2|gpt2-medium|gpt2-large|gpt2-xl|d12|d24|d36|d48\")\n    # token layout for each step of the optimization\n    parser.add_argument(\"--batch_size\", type=int, default=4, help=\"batch size, in units of #batch dimensions\")\n    parser.add_argument(\"--sequence_length\", type=int, default=64, help=\"sequence length\")\n    parser.add_argument(\"--total_batch_size\", type=int, default=256, help=\"total desired batch size, in units of 
#tokens\")\n    # workload (number of steps)\n    parser.add_argument(\"--num_iterations\", type=int, default=10, help=\"number of iterations to run\")\n    parser.add_argument(\"--inference_only\", type=int, default=0, help=\"only run inference\")\n    # optimization\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4, help=\"learning rate warmup iterations\")\n    parser.add_argument(\"--warmup_iters\", type=int, default=0, help=\"learning rate warmup iterations\")\n    parser.add_argument(\"--learning_rate_decay_frac\", type=float, default=1.0, help=\"learning rate warmup iterations\")\n    parser.add_argument(\"--weight_decay\", type=float, default=0.0, help=\"weight decay\")\n    parser.add_argument(\"--grad_clip\", type=float, default=1.0, help=\"maximum gradient magnitude\")\n    # evaluation\n    parser.add_argument(\"--val_loss_every\", type=int, default=0, help=\"every how mant steps to evaluate val loss?\")\n    parser.add_argument(\"--val_max_steps\", type=int, default=20, help=\"how many batches of val to average?\")\n    parser.add_argument(\"--sample_every\", type=int, default=0, help=\"how often to sample from the model?\")\n    # debugging\n    parser.add_argument(\"--overfit_single_batch\", type=int, default=1, help=\"overfit just one batch of data\")\n    # numerics\n    parser.add_argument(\"--tensorcores\", type=int, default=0, help=\"use tensorcores\")\n    # memory management\n    parser.add_argument(\"--device\", type=str, default=\"\", help=\"by default we autodetect, or set it here\")\n    parser.add_argument(\"--compile\", type=int, default=0, help=\"torch.compile the model\")\n    parser.add_argument(\"--flash\", type=int, default=0, help=\"use flash attention\")\n    parser.add_argument(\"--dtype\", type=str, default=\"float32\", help=\"float32|float16|bfloat16\")\n    parser.add_argument(\"--zero_stage\", type=int, default=0, help=\"zero redundancy optimizer stage (0/1/2/3)\")\n    # python -> C bridge\n    
parser.add_argument(\"--write_tensors\", type=int, default=1, help=\"write tensors to disk\")\n    args = parser.parse_args()\n\n    # args error checking and convenience variables\n    B, T = args.batch_size, args.sequence_length\n    assert 1 <= T <= 1024\n    assert args.dtype in {\"float32\", \"float16\", \"bfloat16\"}\n    assert args.model in {\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"d12\", \"d24\", \"d36\", \"d48\"}\n\n    # set up DDP (distributed data parallel). torchrun sets this env variable\n    ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?\n    if ddp:\n        # use of DDP atm demands CUDA, we set the device appropriately according to rank\n        assert torch.cuda.is_available(), \"for now i think we need CUDA for DDP\"\n        init_process_group(backend='nccl')\n        ddp_rank = int(os.environ['RANK'])\n        ddp_local_rank = int(os.environ['LOCAL_RANK'])\n        ddp_world_size = int(os.environ['WORLD_SIZE'])\n        device = f'cuda:{ddp_local_rank}'\n        torch.cuda.set_device(device)\n        master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.\n        seed_offset = 0 # each process gets the exact same seed\n        zero_stage = args.zero_stage\n    else:\n        ddp_rank = 0\n        ddp_local_rank = 0\n        zero_stage = 0\n        ddp_world_size = 1\n        master_process = True\n        seed_offset = 0\n        # select the device\n        if args.device:\n            # provided explicitly by the user\n            device = args.device\n        else:\n            # attempt to autodetect the device\n            device = \"cpu\"\n            if torch.cuda.is_available():\n                device = \"cuda\"\n            elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n                device = \"mps\"\n    print(f\"using device: {device}\")\n    device_type = 'cuda' if 'cuda' in device else 'cpu'\n\n    # calculate gradient accumulation from 
the desired total batch size and the current run configuration\n    tokens_per_fwdbwd = B * T * ddp_world_size\n    assert args.total_batch_size % tokens_per_fwdbwd == 0\n    grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd\n    print0(f\"total desired batch size: {args.total_batch_size}\")\n    print0(f\"=> calculated gradient accumulation steps: {grad_accum_steps}\")\n\n    # set up a context manager following the desired dtype and device\n    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]\n    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == \"cuda\" else nullcontext()\n\n    # rng / reproducibility\n    torch.manual_seed(42)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(42)\n\n    # set the torch precision mode to use TensorFloat32 (TF32) for matmuls\n    # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html\n    if args.tensorcores:\n        torch.set_float32_matmul_precision('high')\n\n    # turn on/off flash attention\n    assert args.flash in {0, 1}\n    FLASH = args.flash\n\n    # init (and write) the tokenizer\n    enc = tiktoken.get_encoding(\"gpt2\")\n    if master_process and args.write_tensors: # tokenizer is technically not tensors but ok\n        write_tokenizer(enc, \"gpt2_tokenizer.bin\")\n\n    # init the model, either from scratch or from OpenAI pretrained checkpoint\n    if args.model[0] == \"d\":\n        # from scratch (random weights)\n        model_config = {\n            \"d12\": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),\n            \"d24\": GPTConfig(block_size=1024, vocab_size=50257, n_layer=24, n_head=16, n_embd=1024),\n            \"d36\": GPTConfig(block_size=1024, vocab_size=50257, n_layer=36, n_head=20, n_embd=1280),\n            \"d48\": GPTConfig(block_size=1024, vocab_size=50257, n_layer=48, n_head=25, n_embd=1600),\n        
}[args.model]\n        model = GPT(model_config)\n    else:\n        # load the GPT-2 model weights\n        model = GPT.from_pretrained(args.model)\n    model.train()\n    model.to(device)\n    if args.compile:\n        if hasattr(config, \"coordinate_descent_tuning\"):\n            config.coordinate_descent_tuning = True # suggested by @Chillee\n        print0(\"compiling the model...\")\n        model = torch.compile(model)\n\n    # -------------------------------------------------------------------------\n    # Our own version of a simple DistributedDataLoader\n\n    # load tokens\n    train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)\n    val_loader = None\n    if args.input_val_bin:\n        val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)\n\n    # -------------------------------------------------------------------------\n    # PyTorch -> C bridge: save some weights and state for C to load later as reference\n\n    # do one forward pass to generate ground truth for our C tests\n    if master_process and args.write_tensors and (not args.inference_only):\n        x, y = train_loader.next_batch()\n        x, y = x.to(device), y.to(device)\n        logits, loss = model(x, y)\n        loss.backward()\n        # save model params, in both float32 and bfloat16\n        model_to_size = {\"gpt2\": \"124M\", \"gpt2-medium\": \"355M\", \"gpt2-large\": \"774M\", \"gpt2-xl\": \"1558M\"}\n        model_to_size.update({f\"d{d}\": f\"d{d}\" for d in [12, 24, 36, 48]})\n        model_size_str = model_to_size[args.model] # e.g. 
\"124M\", or \"d12\"\n        write_model(model, f\"gpt2_{model_size_str}.bin\", dtype=\"float32\")\n        write_model(model, f\"gpt2_{model_size_str}_bf16.bin\", dtype=\"bfloat16\")\n        # save x, y, logits, loss, and parameter gradients, for debugging C\n        # always store these in fp32 to have an accurate reference (?)\n        write_state(model, x, y, logits, loss, f\"gpt2_{model_size_str}_debug_state.bin\")\n        # reset the train_loader for the optimization below\n        train_loader.reset()\n\n    # -------------------------------------------------------------------------\n    # main training loop\n\n    # here we wrap model into DDP container\n    if ddp:\n        model = DDP(model, device_ids=[ddp_local_rank])\n    raw_model = model.module if ddp else model # always contains the \"raw\" unwrapped model\n\n    # init the optimizer\n    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,\n                                               learning_rate=args.learning_rate, betas=(0.9, 0.95),\n                                               device_type=device, zero_stage=zero_stage)\n\n    # learning rate decay scheduler (cosine with warmup)\n    def get_lr(it):\n        min_lr = args.learning_rate * args.learning_rate_decay_frac\n        # 1) linear warmup for warmup_iters steps\n        if it < args.warmup_iters:\n            return args.learning_rate * (it+1) / args.warmup_iters\n        # 2) if it > lr_decay_iters, return min learning rate\n        if it > args.num_iterations:\n            return min_lr\n        # 3) in between, use cosine decay down to min learning rate\n        decay_ratio = (it - args.warmup_iters) / (args.num_iterations - args.warmup_iters)\n        assert 0 <= decay_ratio <= 1\n        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0\n        return min_lr + coeff * (args.learning_rate - min_lr)\n\n    # create the logging directory if it does not exist\n    
logfile = None\n    if args.output_dir:\n        os.makedirs(args.output_dir, exist_ok=True)\n        logfile = os.path.join(args.output_dir, \"main.log\")\n        # create the log file \"main.log\" inside it, and wipe it clean\n        with open(logfile, \"w\") as f:\n            pass\n\n    if device == \"cuda\":\n        torch.cuda.reset_peak_memory_stats()\n    timings = []\n    norm = -1.0   # dummy value to print in inference-only mode\n    for step in range(args.num_iterations + 1):\n        t0 = time.time()\n        last_step = (step == args.num_iterations)\n\n        # once in a while evaluate the validation dataset\n        if (args.val_loss_every > 0 \\\n            and (step % args.val_loss_every == 0 or last_step)) \\\n            and (val_loader is not None):\n            model.eval()\n            val_loader.reset()\n            with torch.no_grad():\n                val_loss = 0.0\n                for _ in range(args.val_max_steps):\n                    x, y = val_loader.next_batch()\n                    x, y = x.to(device), y.to(device)\n                    _, loss = model(x, y, return_logits=False)\n                    val_loss += loss.item()\n                val_loss /= args.val_max_steps\n            # log to console and to file\n            print0(f\"val loss {val_loss}\")\n            if master_process and logfile is not None:\n                with open(logfile, \"a\") as f:\n                    f.write(\"s:%d tel:%f\\n\" % (step, val_loss))\n\n        # once in a while perform model inference on the master process\n        if (args.sample_every > 0 \\\n            and (step % args.sample_every == 0 or last_step)) \\\n            and master_process:\n            model.eval()\n            # before we end, let's also do one round of inference\n            # we'll kick off the generation with \"<|endoftext|>\", which designates the start of a new sequence\n            start_ids = [enc.eot_token]\n            xg = (torch.tensor(start_ids, 
dtype=torch.long, device=device)[None, ...])\n            max_new_tokens = 32\n            temperature = 1.0\n            top_k = 40\n            yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k)\n            print0('---------------')\n            print0(enc.decode(yg[0].tolist()))\n            print0('---------------')\n\n        # bit confusing: we want to make sure to eval and sample on 0th iteration\n        # but also after the very last iteration. so we loop for step <= num_iterations\n        # instead of just < num_iterations (one extra due to <=), only to do\n        # the validation/sampling one last time, and then we break right here as we're done.\n        if last_step:\n            break\n\n        # --------------- TRAINING SECTION BEGIN -----------------\n        model.train()\n        optimizer.zero_grad(set_to_none=True)\n        # if we are trying to overfit a single batch, we reset the loader here\n        if args.overfit_single_batch:\n            train_loader.reset()\n        # micro-batch loop where we do gradient accumulation to reach desired total batch size\n        lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps\n        for micro_step in range(grad_accum_steps):\n            # fetch a batch\n            x, y = train_loader.next_batch()\n            x, y = x.to(device), y.to(device)\n            if ddp:\n                # we want only the last micro-step to sync grads in a DDP model\n                # the official way to do this is with model.no_sync(), but that is a\n                # context manager that bloats the code, so we just toggle this variable\n                model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)\n            # forward pass\n            with ctx:\n                _, loss = model(x, y, return_logits=False)\n                # we have to scale the loss to account for gradient accumulation,\n                # because the 
gradients just add on each successive backward().\n                # addition of gradients corresponds to a SUM in the objective, but\n                # instead of a SUM we want MEAN, so we scale the loss here\n                loss = loss / grad_accum_steps\n                lossf += loss.detach() # keep track of the mean loss\n            # backward pass\n            if not args.inference_only:\n                loss.backward()\n        if ddp:\n            dist.all_reduce(lossf, op=dist.ReduceOp.AVG)\n        lossf = lossf.item()\n        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)\n        # determine and set the learning rate for this iteration\n        lr = get_lr(step)\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n        # step the optimizer\n        optimizer.step()\n        # --------------- TRAINING SECTION END -------------------\n        # everything that follows now is just diagnostics, prints, logging, etc.\n\n        # wait on the CPU for all device work to end so we get accurate per-iteration timings below\n        if device == \"mps\":\n            torch.mps.synchronize()\n        elif device == \"cuda\":\n            torch.cuda.synchronize()\n        # time and print\n        t1 = time.time()\n        # the 0th iteration is often an outlier (much slower) => skip logging it\n        tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)\n        print0(f\"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)\")\n        # log to logile\n        if master_process and logfile is not None:\n            with open(logfile, \"a\") as f:\n                f.write(\"s:%d trl:%f\\n\" % (step, lossf))\n\n        # keep track of smooth timings, last 20 iterations\n        if step > 0 and step > args.num_iterations - 20:\n            timings.append(t1-t0)\n\n    # print the 
average of the last 20 timings, to get something smooth-ish\n    timings = timings[-20:]\n    print0(f\"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms\")\n    print0(f\"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB\")\n\n    # -------------------------------------------------------------------------\n    # clean up nice\n    if ddp:\n        destroy_process_group()\n"
  },
  {
    "path": "train_gpt2_fp32.cu",
    "content": "/*\nGPT-2 Transformer Neural Net trained in raw CUDA\nNon-trivial notes to be aware of:\n\nWe are being clever in the backward pass to conserve memory.\nIn particular, all parameters use a += in the backward pass, so we\ncan later do gradient accumulation. But all activations have = instead of +=\nbecause these are faster (just read, no write). This is okay for all activations\nexcept for those in the residual stream, where the gradients have to add. We make\nsure that those parts work out ok and that we do a += as necessary. E.g.,\nthe layernorms are connected to the residuals so we += in layernorm backward.\n*/\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <math.h>\n#include <time.h>\n#include <assert.h>\n#include <float.h>\n#include <string.h>\n#include <unistd.h>\n\n// GPU / CUDA related\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n// our own utilities\n// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck\n#include \"llmc/utils.h\"\n// defines: tokenizer_init, tokenizer_decode, tokenizer_free\n#include \"llmc/tokenizer.h\"\n// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free\n#include \"llmc/dataloader.h\"\n\n// ----------------------------------------------------------------------------\n// CUDA utils\n\n// convenience macro for calculating grid/block dimensions for kernels\n#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))\n\n// CUDA error checking\nvoid cudaCheck(cudaError_t error, const char *file, int line) {\n  if (error != cudaSuccess) {\n    printf(\"[CUDA ERROR] at file %s:%d:\\n%s\\n\", file, line,\n           cudaGetErrorString(error));\n    exit(EXIT_FAILURE);\n  }\n};\n#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))\n\n// cuBLAS error checking\nvoid cublasCheck(cublasStatus_t status, const char *file, int line)\n{\n    if (status != CUBLAS_STATUS_SUCCESS) {\n        printf(\"[cuBLAS 
ERROR]: %d %s %d\\n\", status, file, line);\n        exit(EXIT_FAILURE);\n    }\n}\n#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }\n\nstatic cublasComputeType_t cublas_compute_type;\ncublasHandle_t cublas_handle;\n\nnamespace cg = cooperative_groups;\n\n// ----------------------------------------------------------------------------\n// all the kernels\n\n__device__ inline float4 add_float4(const float4& a, const float4& b) {\n    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);\n}\n\n// use of float4 leads to using 128-bit LDG / STG instructions in SASS,\n// very helpful in memory-bound kernels like encoder_forward\n__global__ void encoder_forward_kernel3(float4* out,\n                               const int* inp, const float4* wte, const float4* wpe,\n                               int B, int T, int C) {\n    int C4 = C / 4;\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T * C4;\n    if (idx < N) {\n        int bt = idx / C4;\n        int b = bt / T;\n        int t = bt % T;\n        int c4 = idx % C4;\n        int ix = inp[b * T + t];\n        out[b * T * C4 + t * C4 + c4] = add_float4(wte[ix * C4 + c4], wpe[t * C4 + c4]);\n    }\n}\n\n// really bad naive kernel with atomicAdd\n__global__ void encoder_backward_kernel(float* dwte, float* dwpe,\n                                        const float* dout, const int* inp,\n                                        int B, int T, int C) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    int N = B * T * C;\n\n    if (idx < N) {\n        int bt = idx / C;\n        int b = bt / T;\n        int t = bt % T;\n        int c = idx % C;\n\n        int ix = inp[b * T + t];\n\n        const float* dout_btc = dout + b * T * C + t * C + c;\n        float* dwte_ix = dwte + ix * C + c;\n        float* dwpe_tc = dwpe + t * C + c;\n\n        atomicAdd(dwte_ix, *dout_btc);\n        atomicAdd(dwpe_tc, *dout_btc);\n    }\n}\n\n__global__ void 
layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,\n                                    const float*  __restrict__ inp, const float*  __restrict__ weight,\n                                    const float* __restrict__ bias, int N, int C) {\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    if(idx >= N) {\n        return;\n    }\n\n    // the row of input that this group of threads is responsible for\n    const float* x = inp + idx * C;\n\n    // mean\n    float sum = 0.0f;\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        sum += x[i];\n    }\n    sum = cg::reduce(warp, sum, cg::plus<float>{});\n    float m = sum / C;\n    if(warp.thread_rank() == 0 && mean != nullptr) {\n        __stcs(mean + idx, m);\n    }\n\n    // rstd\n    sum = 0.0f;\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        float diff = x[i] - m;\n        sum += diff * diff;\n    }\n    sum = cg::reduce(warp, sum, cg::plus<float>{});\n    float s = rsqrtf(sum / C + 1e-5f);\n    if(warp.thread_rank() == 0 && rstd != nullptr) {\n        __stcs(rstd + idx, s);\n    }\n\n    // final normalization and scaling by weight/bias\n    float* o = out + idx * C;\n    for (int c = warp.thread_rank(); c < C; c += warp.size()) {\n        // load and store using the .cs \"streaming\" hint to the compiler,\n        // indicating that this data will not be reused soon, and can be streamed through the caches\n        // this allows the threads to get more cache-hits for the (shared) weight and bias parameters\n        float n = s * (__ldcs(x+c) - m);\n        __stcs(o+c, n * weight[c] + bias[c]);\n    }\n}\n\n__global__ void permute_kernel(float* q, float* k, float* v,\n                               const float* inp,\n                               int B, int N, int 
NH, int d) {\n    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)\n    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n        q[idx] = __ldcs(&inp[inp_idx]);\n        k[idx] = __ldcs(&inp[inp_idx + NH * d]);\n        v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]);\n    }\n}\n\n__global__ void permute_kernel_backward(float* dinp,\n                                        const float* dq, const float* dk, const float* dv,\n                                        int B, int N, int NH, int d) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n\n        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;\n        dinp[inp_idx] = dq[idx];\n        dinp[inp_idx + NH * d] = dk[idx];\n        dinp[inp_idx + 2 * (NH * d)] = dv[idx];\n    }\n}\n\n__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {\n   // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = 
rest % d;\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        out[other_idx] = __ldcs(&inp[idx]);\n    }\n}\n\n__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < B * NH * N * d) {\n        int b = idx / (NH * N * d);\n        int rest = idx % (NH * N * d);\n        int nh_ = rest / (N * d);\n        rest = rest % (N * d);\n        int n = rest / d;\n        int d_ = rest % d;\n        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;\n        dinp[idx] = dout[other_idx];\n    }\n}\n\n__device__ float& vec_at(float4& vec, int index) {\n    return reinterpret_cast<float*>(&vec)[index];\n}\n\n__device__ float vec_at(const float4& vec, int index) {\n    return reinterpret_cast<const float*>(&vec)[index];\n}\n\n__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {\n    // inp, out shape: (N, T, T), where N = B * NH\n    // fuses the multiplication by scale inside attention\n    // directly autoregressive, so we only compute the lower triangular part\n    // uses the online softmax algorithm\n    assert(T % 4  == 0);\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    // micro-optimization: we iterate backwards so that\n    // after the softmax backward operation completes, the cache retains the\n    // part of the matrix close to the upper left corner, which benefits the\n    // matmul operation that immediately follows.\n    // int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); // forward order\n    int idx = (gridDim.x - blockIdx.x - 1) * warp.meta_group_size() + warp.meta_group_rank(); // backward order\n    if(idx >= N * T) {\n        return;\n    }\n    int own_pos = idx % T;\n    int pos_by_4 = own_pos / 4;\n\n    // one row of inp, i.e. 
inp[idx, :] of shape (T,)\n    const float* x = inp + idx * T;\n\n    // not INF, so we don't get NaNs accidentally when subtracting two values.\n    float maxval = -FLT_MAX;\n    float sumval = 0.0f;\n\n    const float4* x_vec = reinterpret_cast<const float4*>(x);\n    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {\n        float4 v = x_vec[i];\n        float old_maxval = maxval;\n        for(int k = 0; k < 4; ++k) {\n            maxval = fmaxf(maxval, vec_at(v, k));\n        }\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        for(int k = 0; k < 4; ++k) {\n            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));\n        }\n    }\n\n    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {\n        float old_maxval = maxval;\n        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);\n        sumval *= expf(inv_temperature * (old_maxval - maxval));\n        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));\n    }\n\n    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});\n    sumval *= expf(inv_temperature * (maxval - global_maxval));\n\n    float sum = cg::reduce(warp, sumval, cg::plus<float>{});\n    float norm = 1.f / sum;\n\n    // divide the whole row by the sum\n    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {\n        // recalculation is faster than doing the round-trip through memory.\n        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));\n        __stcs(out + idx * T + i, ev * norm);\n    }\n}\n\n__global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < N) {\n        out[idx] = __ldcs(&inp1[idx]) + __ldcs(&inp2[idx]);\n    }\n}\n\n#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)\n__global__ void gelu_forward_kernel(float* out, const float* inp, int N) {\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n   
 if (i < N) {\n        float xi = inp[i];\n        float cube = 0.044715f * xi * xi * xi;\n        out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));\n    }\n}\n\n__global__ void gelu_backward_kernel(float* dinp, const float* inp, const float* dout, const int N) {\n    int i = blockIdx.x * blockDim.x + threadIdx.x;\n    if (i < N) {\n        float x = inp[i];\n        float cube = 0.044715f * x * x * x;\n        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);\n        float tanh_out = tanhf(tanh_arg);\n        float coshf_out = coshf(tanh_arg);\n        float sech_out = 1.0f / (coshf_out * coshf_out);\n        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);\n        dinp[i] = local_grad * dout[i];\n    }\n}\n\n// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:\n// dbias = dout.sum((0,1))\n// the idea is to employ one block to reduce along several columns,\n// where each block has a width of 32 columns to ensure coalesced access.\n// at the end we accumulate the reductions performed by the warps in each block via shared memory\n__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {\n    // this kernel is launched with 1D grid_dim of OC/32\n    // for example let's say block_size is 128\n    extern __shared__ float smem[]; // of size block_size (128)\n    const int warp_id = threadIdx.x / warpSize; // warp index in the block, 0,1,2,3\n    const int lane_id = threadIdx.x % warpSize; // thread index in the warp, 0,1,2,...,31\n    const int tl = blockIdx.x * warpSize; // pointer to the start column for this block\n    const int vstep = blockDim.x / warpSize; // number of warps in a block, e.g. 4\n\n    // pointer to the start of the column for one lane of threads\n    // so e.g. 
4 threads (of the same lane_id) will reduce this one column\n    const float* dout_col = dout + tl + lane_id;\n\n    // column reductions by looping through the rows\n    // each of the 4 threads offsets by its warp_id and then skips by vstep\n    // together these 4 threads cover all B*T rows of this (lane_id) column\n    // importantly, consecutive threads (in threadId) are processing adjacent columns,\n    // leading to a coalesced memory access pattern\n    float dout_sum = 0.0f;\n    for (int row = warp_id; row < B * T; row += vstep) {\n        dout_sum += dout_col[row * OC];\n    }\n    smem[lane_id + warp_id * warpSize] = dout_sum;\n    __syncthreads();\n\n    // warp_id 0 reduces the shared memory column-wise, linearly\n    dout_sum = 0.0f;\n    if (warp_id == 0) {\n        for (int j = 0; j < vstep; j++) {\n            dout_sum += smem[lane_id + j * warpSize];\n        }\n        dbias[tl + lane_id] += dout_sum;\n    }\n}\n\n// uses shared memory instead for the reduces\n__global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias,\n                                           const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,\n                                           int B, int T, int C) {\n    extern __shared__ float shared[]; // size = 2 * C\n\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();\n    int N = B * T;\n    if(idx >= N) { return; } // thread guards\n\n    int b = idx / T;\n    int t = idx % T;\n\n    const float* dout_bt = dout + b * T * C + t * C;\n    const float* inp_bt = inp + b * T * C + t * C;\n    float* dinp_bt = dinp + b * T * C + t * C;\n    const float mean_bt = mean[b * T + t];\n    const float rstd_bt = rstd[b * T + t];\n\n    // the first half of shared memory is bias, 
second is weight\n    float* dbias_shared = shared;\n    float* dweight_shared = shared + C;\n\n    // init shared memory to zero\n    #pragma unroll\n\tfor(int i = threadIdx.x; i < C; i+= blockDim.x){\n       dbias_shared[i] = 0.0f;\n       dweight_shared[i] = 0.0f;\n    }\n    __syncthreads();\n\n    // first: two reduce operations\n    float dnorm_mean = 0.0f;\n    float dnorm_norm_mean = 0.0f;\n    for (int i = warp.thread_rank(); i < C; i  += warp.size()) {\n        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n        float dnorm_i = weight[i] * dout_bt[i];\n        dnorm_mean += dnorm_i;\n        dnorm_norm_mean += dnorm_i * norm_bti;\n    }\n    dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});\n    dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});\n    dnorm_mean = dnorm_mean / C;\n    dnorm_norm_mean = dnorm_norm_mean / C;\n\n    // now iterate again and accumulate all the gradients\n    for (int i = warp.thread_rank(); i < C; i += warp.size()) {\n        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;\n        float dnorm_i = weight[i] * dout_bt[i];\n        // gradient contribution to bias\n        atomicAdd(&dbias_shared[i], dout_bt[i]);\n        // gradient contribution to weight\n        atomicAdd(&dweight_shared[i], norm_bti * dout_bt[i]);\n        // gradient contribution to input\n        float dval = 0.0f;\n        dval += dnorm_i; // term 1\n        dval -= dnorm_mean; // term 2\n        dval -= norm_bti * dnorm_norm_mean; // term 3\n        dval *= rstd_bt; // final scale\n        dinp_bt[i] += dval;\n    }\n    __syncthreads();\n\n    // write to global memory\n\tfor(int i = threadIdx.x; i < C; i+= blockDim.x){\n        atomicAdd(&dbias[i], dbias_shared[i]);\n        atomicAdd(&dweight[i], dweight_shared[i]);\n\t}\n}\n\n__global__ void softmax_autoregressive_backward_kernel(float* dpreatt, const float* datt, const float* att,\n                                                       int B, int T, int C, 
float scale) {\n    constexpr const int BlockSize = 256;\n    constexpr int T_per_block = 4;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    __shared__ float block_acc[32];\n\n    int idx = blockIdx.y;\n    // go through blocks in reverse order, so the slowest block starts first\n    int t0 = T - 1 - T_per_block*blockIdx.x;\n\n    att += idx * T * T;\n    datt += idx * T * T;\n    dpreatt += idx * T * T;\n\n    if (warp.meta_group_rank() == 0) {\n        block_acc[warp.thread_rank()] = 0;\n    }\n\n    for(int to = 0; to < T_per_block; ++to) {\n        int t = t0 - to;\n        if(t < 0) return;\n        const float* att_bth = att + t * T;\n        const float* datt_bth = datt + t * T;\n        float* dpreatt_bth = dpreatt + t * T;\n\n        float local_sum = 0;\n        for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {\n            local_sum += att_bth[t2] * datt_bth[t2];\n        }\n\n        block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus<float>{});\n        block.sync();\n        local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus<float>{});\n\n        for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {\n            // don't touch the cache. 
Some parts will still be here from the previous loop, and\n            // we want to exploit those.\n            float acc = __ldcs(att_bth + t3) * (__ldcs(datt_bth + t3) - local_sum);\n            __stcs(dpreatt_bth + t3, scale * acc);\n        }\n    }\n}\n\n// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation).\n// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda\n__device__ inline float lerp(float start, float end, float weight) {\n    return fma(weight, end, fma(-weight, start, start));\n}\n\n__global__ void adamw_kernel2(float* params_memory, float* grads_memory, float* m_memory, float* v_memory, long num_parameters,\n                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {\n   int i = blockIdx.x * blockDim.x + threadIdx.x;\n   if (i >= num_parameters) return;  // guard\n   float grad = grads_memory[i];\n   float m = m_memory[i];\n   float v = v_memory[i];\n   // update the first moment (momentum)\n   m = lerp(grad, m, beta1);\n   m_memory[i] = m;\n   // update the second moment (RMSprop)\n   v = lerp(grad * grad, v, beta2);\n   v_memory[i] = v;\n   m /= beta1_correction;  // m_hat\n   v /= beta2_correction;  // v_hat\n   params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);\n}\n\nstruct SoftmaxParams {\n    float Scale;\n    float Offset;\n};\n\n\n__device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp,\n                                                   int idx, const float* inp, int V, int P) {\n    // same but not float4\n    // one row of inp, i.e. 
inp[idx, :] of shape (V,)\n\n    const float* x = inp + idx * P;\n    float thread_maxval = -INFINITY;\n    float thread_sumval = 0.0f;\n    // do the loop in reverse to maximise probability of L2 cache hits\n    // so even small L2s get some hits on the 2nd read of the same thread\n    for (int i = V + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {\n        float v = x[i];\n        float old_maxval = thread_maxval;\n        thread_maxval = fmaxf(thread_maxval, v);\n        thread_sumval *= expf((old_maxval - thread_maxval));\n        thread_sumval += expf(v - thread_maxval);\n    }\n\n    // two reductions of up to 1024 threads:\n    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)\n    // this results in much cleaner assembly than a multi-warp cg::reduce\n    __shared__ float shared_maxval[32];\n    __shared__ float shared_sumval[32];\n    int num_warps = blockDim.x / 32;\n    int warp_id = threadIdx.x / 32;\n    int lane_id = threadIdx.x % 32;\n\n    // reduce maxval within each warp\n    float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater<float>{});\n    // thread 0 in each warp writes to shared memory\n    if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }\n    __syncthreads();\n    // each thread now loads the maxval across previous warps\n    // if the thread is \"out of range\" of data, use -FLT_MAX as the maxval\n    warp_maxval = (lane_id < num_warps) ? 
shared_maxval[lane_id] : -FLT_MAX;\n    // now reduce the maxval among the warp threads\n    float block_maxval = cg::reduce(warp, warp_maxval, cg::greater<float>{});\n    // each thread uses maxval to scale sumval to avoid numerical instability / overflow\n    thread_sumval *= expf(thread_maxval - block_maxval);\n    // (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory\n    float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus<float>{});\n    if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }\n    __syncthreads();\n    // same strategy, now reduce sumval across warps\n    warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f;\n    float block_sumval = cg::reduce(warp, warp_sumval, cg::plus<float>{});\n    // return the softmax parameters\n    return SoftmaxParams{1.f / block_sumval, block_maxval};\n}\n\n// same as 2 but not using float4 (see dev/cuda/classifier_fused.cu)\n// will _update_ logits to logit gradients\n__global__ void fused_classifier_kernel3(float* logits, float* losses, float* probs,\n                                         const float* dlosses, const int* targets,\n                                         int B, int T, int V, int P) {\n    namespace cg = cooperative_groups;\n    cg::thread_block block = cg::this_thread_block();\n    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);\n    int idx = blockIdx.x;\n    int ix = targets[idx];\n\n    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)\n    SoftmaxParams sp = prepare_softmax_blockwide_nofloat4(warp, idx, logits, V, P);\n\n    // calculate the probability needed for the loss and update (single-threaded)\n    if(threadIdx.x == 0) {\n        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;\n        losses[idx] = -logf(prob);\n    }\n\n    // very sensible default for dlosses is 1/(B*T), which is the uniform loss\n    float dloss = dlosses != NULL ? 
dlosses[idx] : 1.0f / (B*T);\n    // calculate the gradients directly, saves bandwidth from probs during training\n    // but also supports writing probs for inference-only and debugging\n    const float* logits_vec = logits + idx * P;\n    for (int i = threadIdx.x; i < V; i += blockDim.x) {\n        // this is the 2nd read of logits after the one in prepare_softmax2\n        // this data will never be needed again, so we reduce cache persistence\n        float v = __ldcs(&logits_vec[i]);\n        float prob = expf(v - sp.Offset) * sp.Scale;\n        if (probs != NULL) {\n            probs[idx * P + i] = prob;\n        }\n        float indicator = (i == ix) ? 1.0f : 0.0f;\n        logits[idx * P + i] = (prob - indicator) * dloss;\n    }\n}\n\n__device__ float4 ld_vec(const float* address) {\n    return *reinterpret_cast<const float4*>(address);\n}\n\n__device__ void st_vec(float* address, float4 val) {\n    *reinterpret_cast<float4*>(address) = val;\n}\n\n__global__ void __launch_bounds__(16*16, 2) matmul_forward_kernel4(float* out,\n                                                                   const float* inp, const float* weight, const float* bias,\n                                                                   int C, int OC) {\n    // out is (B,T,OC). OC is short for \"output channels\", e.g. 
OC = 4 * C\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    // each thread handles 8x8 elements; each block 128 by 128 elements.\n    int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y);\n\n    // buffers to cache chunks of the input matrices\n    __shared__ float lhs_s[128][32];\n    __shared__ float rhs_s[128][32];\n\n    // adjust our pointers for the current block\n    inp += 128 * blockIdx.x * C;\n    weight += 128 * blockIdx.y * C;\n    out += 128 * blockIdx.x * OC + 128 * blockIdx.y;\n\n    float vals[8][8] = {};\n    if(bias != NULL) {\n        for (int i = 0; i < 8; i++) {\n            for (int j = 0; j < 8; j += 4) {\n                float4 b = ld_vec(bias + oc + j);\n                vals[i][j+0] = b.x;\n                vals[i][j+1] = b.y;\n                vals[i][j+2] = b.z;\n                vals[i][j+3] = b.w;\n            }\n        }\n    }\n\n    int si_start = 4*(16 * threadIdx.y + threadIdx.x);\n    for (int so = 0; so < C; so += 32) {\n        __syncthreads();\n        int xmod8 = threadIdx.x % 8;\n        int xby8 = threadIdx.x / 8;\n        int xo = 4 * xmod8;\n        for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) {\n            st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo));\n            st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo));\n        }\n        __syncthreads();\n\n        for (int si = si_start; si < si_start + 32; si += 4) {\n            float4 rhs[8];\n            for (int u = 0; u < 8; ++u) {\n                rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]);\n            }\n\n            for (int ii = 0; ii < 8; ++ii) {\n                float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]);\n                for (int ji = 0; ji < 8; ++ji) {\n                    vals[ii][ji] += lhs.x * rhs[ji].x;\n                    vals[ii][ji] += lhs.y * rhs[ji].y;\n                    vals[ii][ji] += lhs.z * rhs[ji].z;\n                    vals[ii][ji] += lhs.w * rhs[ji].w;\n                }\n          
  }\n        }\n    }\n\n    for (int i = 0; i < 8; ++i) {\n        for (int j = 0; j < 8; j += 4) {\n            float4 result;\n            result.x = vals[i][j + 0];\n            result.y = vals[i][j + 1];\n            result.z = vals[i][j + 2];\n            result.w = vals[i][j + 3];\n            st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result);\n        }\n    }\n}\n\n\n// ----------------------------------------------------------------------------\n// kernel launchers\n\nvoid encoder_forward(float* out,\n                     const int* inp, const float* wte, const float* wpe,\n                     int B, int T, int C) {\n    assert(C % 4 == 0);\n    const int block_size = 512;\n    const int N = B * T * C;\n    const int grid_size = CEIL_DIV(N / 4, block_size);\n    encoder_forward_kernel3<<<grid_size, block_size>>>((float4*) out, inp, (float4*) wte, (float4*) wpe, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid encoder_backward(float* dwte, float* dwpe,\n                    const float* dout, const int* inp,\n                    int B, int T, int C) {\n    const int N = B * T * C;\n    const int block_size = 256;\n    const int grid_size = CEIL_DIV(N, block_size);\n    encoder_backward_kernel<<<grid_size, block_size>>>(dwte, dwpe, dout, inp, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid layernorm_forward(float* out, float* mean, float* rstd,\n                       float* inp, float* weight, float* bias,\n                       int B, int T, int C) {\n    const int block_size = 512;\n    const int N = B * T;\n    const int grid_size = CEIL_DIV(N * 32, block_size);\n    layernorm_forward_kernel3<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);\n    cudaCheck(cudaGetLastError());\n}\n\n// kernel 1 is the most naive matmul kernel\nvoid matmul_forward(float* out,\n                    const float* inp, const float* weight, const float* bias,\n                    int B, int T, int C, int OC) {\n    // out is 
(B,T,OC). OC is short for \"output channels\", e.g. OC = 4 * C\n    // inp is (B,T,C), weight is (OC, C), bias is (OC)\n    int sqrt_block_size = 16;\n\n    dim3 gridDim(CEIL_DIV(B * T, 8*sqrt_block_size), CEIL_DIV(OC, 8*sqrt_block_size));\n    dim3 blockDim(sqrt_block_size, sqrt_block_size);\n    matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, C, OC);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid attention_forward(float* out, float* qkvr, float* att,\n                       float* inp,\n                       int B, int T, int C, int NH) {\n    // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.\n    // Its contents will be overwritten by this function.\n    const int block_size = 256;\n    const int softmax_block_size = 256;\n\n    // inp is (B, T, 3C) QKV\n    // preatt, att are (B, NH, T, T)\n    // output is (B, T, C)\n    int HS = C / NH; // head size\n\n    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)\n    float *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    int total_threads = B * NH * T * HS;\n    int num_blocks = CEIL_DIV(total_threads, block_size);\n    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n\n    // batched matrix multiply with cuBLAS\n    const float alpha = 1.0f;\n    const float beta = 0.0f;\n    float* preatt = inp;\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &alpha, k, HS, T * HS, q, HS, T * HS, &beta, preatt, T, T * T, B * NH));\n\n    // multiply all elements of preatt elementwise by scale\n    float scale = 1.0 / sqrtf(HS);\n    int grid_size = CEIL_DIV(B * NH * T * 32, softmax_block_size);\n    softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);\n    cudaCheck(cudaGetLastError());\n\n    // new approach: first cuBLAS another batched matmul\n    
float* vaccum = inp;\n    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha, v, HS, T * HS, att, T, T * T, &beta, vaccum, HS, T * HS, B * NH));\n\n    // now unpermute\n    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side\n    num_blocks = CEIL_DIV(B * T * C, block_size);\n    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid residual_forward(float* out, float* inp1, float* inp2, int N) {\n    const int block_size = 256;\n    const int grid_size = CEIL_DIV(N, block_size);\n    residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid gelu_forward(float* out, const float* inp, int N) {\n    const int block_size = 128;\n    const int grid_size = CEIL_DIV(N, block_size);\n    gelu_forward_kernel<<<grid_size, block_size>>>(out, inp, N);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid gelu_backward(float* dinp, const float* inp, const float* dout, const int N) {\n    const int block_size = 128;\n    const int grid_size = CEIL_DIV(N, block_size);\n    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout, N);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid matmul_backward(float* dinp, float* dweight, float* dbias,\n                     float* dout, float* inp, float* weight,\n                     int B, int T, int C, int OC) {\n    float one = 1.0f;\n    float zero = 0.0f;\n    // backward to input, uses = in the backward pass (set the gradient)\n    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one, weight, C, dout, OC, &zero, dinp, C));\n    // backward to weight, uses += in the backward pass (accumulate the gradient)\n    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, 
C));\n    // backward to bias, if given, does a +=\n    if (dbias != NULL) {\n        const int block_size = 1024;\n        const int grid_size = OC / 32; // for now, OC must be divisible by 32 for this kernel to work\n        matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);\n        cudaCheck(cudaGetLastError());\n    }\n}\n\nvoid layernorm_backward(float* dinp, float* dweight, float* dbias,\n                        const float* dout, const float* inp, const  float* weight, const float* mean, const float* rstd,\n                        int B, int T, int C) {\n    const int block_size = 512;\n    const int N = B * T;\n    const int grid_size = CEIL_DIV(32*N, block_size);\n    size_t shared_mem_size = 2 * C * sizeof(float);\n    layernorm_backward_kernel2<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);\n    cudaCheck(cudaGetLastError());\n}\n\n// the sequence of transformations in this compound op is:\n// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)\nvoid attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* scratch,\n                        const float* dout,\n                        const float* qkvr, const float* att,\n                        int B, int T, int C, int NH) {\n    const int block_size = 256;\n    int HS = C / NH; // head size\n    const float one = 1.0f;\n    const float zero = 0.0f; // note beta = 1.0f so that we accumulate gradients (+=)\n    // unpack convenience pointers into q, k, v\n    const float *q, *k, *v;\n    q = qkvr + 0 * B * T * C;\n    k = qkvr + 1 * B * T * C;\n    v = qkvr + 2 * B * T * C;\n    float *dq, *dk, *dv;\n    dq = dqkvr + 0 * B * T * C;\n    dk = dqkvr + 1 * B * T * C;\n    dv = dqkvr + 2 * B * T * C;\n    // backward through the unpermute operation\n    int num_blocks = CEIL_DIV(B * T * C, block_size);\n    
unpermute_kernel_backward<<<num_blocks, block_size>>>(scratch, dout, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n    // backward into datt\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &one, v, HS, T * HS, scratch, HS, T * HS, &zero, datt, T, T * T, B * NH));\n    // backward into dv\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, scratch, HS, T * HS, att, T, T * T, &zero, dv, HS, T * HS, B * NH));\n    // backward into preatt\n    int hs = C / NH; // head size\n    float scale = 1.0f / sqrtf(hs);\n    softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256>>>(dpreatt, datt, att, B, T, C, scale);\n    cudaCheck(cudaGetLastError());\n    // backward into q\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &one, k, HS, T * HS, dpreatt, T, T * T, &zero, dq, HS, T * HS, B * NH));\n    // backward into k\n    cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, q, HS, T * HS, dpreatt, T, T * T, &zero, dk, HS, T * HS, B * NH));\n    // backward into inp\n    num_blocks = CEIL_DIV(B * NH * T * HS, block_size);\n    permute_kernel_backward<<<num_blocks, block_size>>>(dinp, dq, dk, dv, B, T, NH, HS);\n    cudaCheck(cudaGetLastError());\n}\n\n// replaces logits with logit gradients\nvoid fused_classifier3(float* logits, float* losses,\n                      const float* dlosses, const int* targets,\n                      int B, int T, int V, int P) {\n    const int block_size = 1024;\n    const int N = B * T;\n    const int grid_size = N;\n    fused_classifier_kernel3<<<grid_size, block_size>>>(logits, losses, NULL, dlosses, targets, B, T, V, P);\n    cudaCheck(cudaGetLastError());\n}\n\n// ----------------------------------------------------------------------------\n// GPT-2 model definition\n\ntypedef struct {\n    int max_seq_len; // max sequence length, e.g. 
1024\n    int vocab_size; // vocab size, e.g. 50257\n    int padded_vocab_size; // padded to e.g. %128==0, 50304\n    int num_layers; // number of layers, e.g. 12\n    int num_heads; // number of heads in attention, e.g. 12\n    int channels; // number of channels, e.g. 768\n} GPT2Config;\n\n// the parameters of the model\n#define NUM_PARAMETER_TENSORS 16\ntypedef struct {\n    float* wte; // (V, C)\n    float* wpe; // (maxT, C)\n    float* ln1w; // (L, C)\n    float* ln1b; // (L, C)\n    float* qkvw; // (L, 3*C, C)\n    float* qkvb; // (L, 3*C)\n    float* attprojw; // (L, C, C)\n    float* attprojb; // (L, C)\n    float* ln2w; // (L, C)\n    float* ln2b; // (L, C)\n    float* fcw; // (L, 4*C, C)\n    float* fcb; // (L, 4*C)\n    float* fcprojw; // (L, C, 4*C)\n    float* fcprojb; // (L, C)\n    float* lnfw; // (C)\n    float* lnfb; // (C)\n} ParameterTensors;\n\nvoid fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) {\n    int Vp = config.padded_vocab_size;\n    int C = config.channels;\n    int maxT = config.max_seq_len;\n    int L = config.num_layers;\n    param_sizes[0] = Vp * C; // wte\n    param_sizes[1] = maxT * C; // wpe\n    param_sizes[2] = L * C; // ln1w\n    param_sizes[3] = L * C; // ln1b\n    param_sizes[4] = L * (3 * C) * C; // qkvw\n    param_sizes[5] = L * (3 * C); // qkvb\n    param_sizes[6] = L * C * C; // attprojw\n    param_sizes[7] = L * C; // attprojb\n    param_sizes[8] = L * C; // ln2w\n    param_sizes[9] = L * C; // ln2b\n    param_sizes[10] = L * (4 * C) * C; // fcw\n    param_sizes[11] = L * (4 * C); // fcb\n    param_sizes[12] = L * C * (4 * C); // fcprojw\n    param_sizes[13] = L * C; // fcprojb\n    param_sizes[14] = C; // lnfw\n    param_sizes[15] = C; // lnfb\n}\n\n// allocate memory for the parameters and point the individual tensors to the right places\nfloat* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes, int on_device) {\n    // on_device: 0 = CPU, 1 = GPU\n    // calculate the number 
of parameters\n    size_t num_parameters = 0;\n    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        num_parameters += param_sizes[i];\n    }\n    // malloc all parameters all at once on the device\n    float* params_memory;\n    if (on_device) {\n        cudaCheck(cudaMalloc((void**)&params_memory, num_parameters * sizeof(float)));\n    } else {\n        params_memory = (float*)mallocCheck(num_parameters * sizeof(float));\n    }\n    // assign all the tensors their place in the array\n    float** ptrs[] = {\n        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,\n        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,\n        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb\n    };\n    float* params_memory_iterator = params_memory;\n    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        *(ptrs[i]) = params_memory_iterator;\n        params_memory_iterator += param_sizes[i];\n    }\n    return params_memory;\n}\n\n#define NUM_ACTIVATION_TENSORS 21\ntypedef struct {\n    float* encoded; // (B, T, C)\n    float* ln1; // (L, B, T, C)\n    float* ln1_mean; // (L, B, T)\n    float* ln1_rstd; // (L, B, T)\n    float* atty; // (L, B, T, C)\n    float* att; // (L, B, NH, T, T)\n    float* attproj; // (L, B, T, C)\n    float* residual2; // (L, B, T, C)\n    float* ln2; // (L, B, T, C)\n    float* ln2_mean; // (L, B, T)\n    float* ln2_rstd; // (L, B, T)\n    float* fch; // (L, B, T, 4*C)\n    float* fch_gelu; // (L, B, T, 4*C)\n    float* fcproj; // (L, B, T, C)\n    float* residual3; // (L, B, T, C)\n    float* lnf; // (B, T, C)\n    float* lnf_mean; // (B, T)\n    float* lnf_rstd; // (B, T)\n\n    float* losses; // (B, T)\n    // adding these two compared to the CPU .c code, needed for attention kernel as buffers\n    float* qkvr; // (L, B, T, 3*C)\n    // in inference mode, this buffer will store the logits\n    // in training mode, this buffer 
will contain the *gradients* of the logits.\n    // during the processing of transformer blocks, we will also use this as a\n    // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C),\n    // (B, NH, T, T), and (B, T, V) shaped tensors.\n    float* output;\n} ActivationTensors;\n\nvoid fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config) {\n    size_t Vp = config.padded_vocab_size;\n    size_t L = config.num_layers;\n    size_t NH = config.num_heads;\n    size_t C = config.channels;\n    act_sizes[0] = B * T * C; // encoded\n    act_sizes[1] = L * B * T * C; // ln1\n    act_sizes[2] = L * B * T; // ln1_mean\n    act_sizes[3] = L * B * T; // ln1_rstd\n    act_sizes[4] = L * B * T * C; // atty\n    act_sizes[5] = L * B * NH * T * T; // att\n    act_sizes[6] = L * B * T * C; // attproj\n    act_sizes[7] = L * B * T * C; // residual2\n    act_sizes[8] = L * B * T * C; // ln2\n    act_sizes[9] = L * B * T; // ln2_mean\n    act_sizes[10] = L * B * T; // ln2_rstd\n    act_sizes[11] = L * B * T * 4*C; // fch\n    act_sizes[12] = L * B * T * 4*C; // fch_gelu\n    act_sizes[13] = L * B * T * C; // fcproj\n    act_sizes[14] = L * B * T * C; // residual3\n    act_sizes[15] = B * T * C; // lnf\n    act_sizes[16] = B * T; // lnf_mean\n    act_sizes[17] = B * T; // lnf_rstd\n    act_sizes[18] = B * T; // losses\n    act_sizes[19] = L * B * T * 3*C; // qkvr\n    act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch\n}\n\n// Backward pass is conceptually quite different from forward, because we can discard\n// the activations of a layer as soon as we're done with it. 
This lets us aggressively\n// reuse memory, so that we need far fewer tensors for backward state.\n#define NUM_BACKWARD_TENSORS 3\ntypedef struct {\n    float* bt4c; // (B, T, 4*C)\n    float* preatt; // (B, NH, T, T)\n    float* residual3; // (B, T, C)\n} GradActTensors;\n\n\nvoid fill_in_grad_act_sizes(size_t* act_sizes, int B, int T, GPT2Config config) {\n    size_t NH = config.num_heads;\n    size_t C = config.channels;\n    act_sizes[0] = B * T * 4 * C; // bt4c\n    act_sizes[1] = B * NH * T * T; // preatt\n    act_sizes[2] = B * T * C; // residual3\n}\n\n\nfloat* malloc_and_point(float** targets[], const size_t* act_sizes, int n) {\n    size_t num_activations = 0;\n    for (size_t i = 0; i < n; i++) {\n        num_activations += act_sizes[i];\n    }\n    float* acts_memory;\n    cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(float)));\n    float* acts_memory_iterator = acts_memory;\n    for (size_t i = 0; i < n; i++) {\n        *(targets[i]) = acts_memory_iterator;\n        acts_memory_iterator += act_sizes[i];\n    }\n    return acts_memory;\n}\n\nfloat* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_sizes) {\n    float** ptrs[] = {\n        &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty,\n        &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,\n        &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,\n        &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output\n    };\n    return malloc_and_point(ptrs, act_sizes, NUM_ACTIVATION_TENSORS);\n}\n\nfloat* malloc_and_point_backward(GradActTensors* acts, const size_t* act_sizes) {\n    float** ptrs[] = {\n        &acts->bt4c, &acts->preatt, &acts->residual3\n    };\n    return malloc_and_point(ptrs, act_sizes, NUM_BACKWARD_TENSORS);\n}\n\ntypedef struct {\n    GPT2Config config;\n    // the weights of the model, and their sizes\n    ParameterTensors 
params;\n    size_t param_sizes[NUM_PARAMETER_TENSORS];\n    float* params_memory;\n    size_t num_parameters;\n    // gradients of the weights\n    ParameterTensors grads;\n    float* grads_memory;\n    // buffers for the AdamW optimizer\n    float* m_memory;\n    float* v_memory;\n    // the activations of the model, and their sizes\n    ActivationTensors acts;\n    size_t act_sizes[NUM_ACTIVATION_TENSORS];\n    float* acts_memory;\n    size_t num_activations;\n    // gradients of the activations\n    GradActTensors grads_acts;\n    size_t num_grad_acts;\n    float* grads_acts_memory;\n    // other run state configuration\n    int batch_size; // the batch size (B) of current forward pass\n    int seq_len; // the sequence length (T) of current forward pass\n    int* inputs; // the input tokens for the current forward pass\n    int* targets; // the target tokens for the current forward pass\n    float mean_loss; // after a forward pass with targets, will be populated with the mean loss\n    float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost\n} GPT2;\n\nvoid gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {\n\n    // read in model from a checkpoint file\n    FILE *model_file = fopenCheck(checkpoint_path, \"rb\");\n    int model_header[256];\n    freadCheck(model_header, sizeof(int), 256, model_file);\n    if (model_header[0] != 20240326) { fprintf(stderr, \"Bad magic model file\\n\"); exit(EXIT_FAILURE); }\n    if (model_header[1] != 3) {\n        // was bumped from 1 -> 3 to incorporate the padded vocab size\n        fprintf(stderr, \"Bad version in model file\\n\");\n        fprintf(stderr, \"---> HINT: try to re-run `python train_gpt2.py`\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // read in hyperparameters\n    model->config.max_seq_len = model_header[2];\n    model->config.vocab_size = model_header[3];\n    model->config.num_layers = model_header[4];\n    model->config.num_heads = model_header[5];\n 
   model->config.channels = model_header[6];\n    model->config.padded_vocab_size = model_header[7];\n\n    // allocate space for all the parameters and read them in\n    fill_in_parameter_sizes(model->param_sizes, model->config);\n\n    // count the number of parameters\n    size_t num_parameters = 0;\n    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {\n        num_parameters += model->param_sizes[i];\n    }\n    model->num_parameters = num_parameters;\n\n    // create memory for model parameters on the device\n    model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes, 1);\n\n    // read in all the parameters from file and copy them to device\n    float* params_memory_cpu = (float*)mallocCheck(num_parameters * sizeof(float));\n    freadCheck(params_memory_cpu, sizeof(float), num_parameters, model_file);\n    cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, num_parameters * sizeof(float), cudaMemcpyHostToDevice));\n    free(params_memory_cpu);\n    fcloseCheck(model_file);\n\n    // other inits\n    model->acts_memory = NULL;\n    model->grads_memory = NULL;\n    model->m_memory = NULL;\n    model->v_memory = NULL;\n    model->grads_acts_memory = NULL;\n    model->inputs = NULL;\n    model->targets = NULL;\n    model->cpu_losses = NULL;\n    model->batch_size = 0;\n    model->seq_len = 0;\n    model->mean_loss = -1.0f; // -1.0f will designate no loss\n}\n\nvoid gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {\n    // targets are optional and could be NULL\n\n    // ensure the model was initialized or error out\n    if (model->params_memory == NULL) {\n        printf(\"Error: model was not initialized properly.\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // convenience parameters\n    int V = model->config.vocab_size;\n    int Vp = model->config.padded_vocab_size;\n    int L = model->config.num_layers;\n    int NH = model->config.num_heads;\n    int C = model->config.channels;\n\n    // 
validate inputs, all indices must be in the range [0, V)\n    for(int i = 0; i < B * T; i++) {\n        assert(0 <= inputs[i] && inputs[i] < V);\n        if (targets != NULL) {\n            assert(0 <= targets[i] && targets[i] < V);\n        }\n    }\n\n    // allocate space for all the activations if needed (done here, lazily)\n    if(model->acts_memory == NULL) {\n        // record the current B,T as well\n        model->batch_size = B;\n        model->seq_len = T;\n        // and now allocate the space\n        fill_in_activation_sizes(model->act_sizes, B, T, model->config);\n        size_t num_activations = 0;\n        for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {\n            num_activations += model->act_sizes[i];\n        }\n        model->num_activations = num_activations;\n        model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);\n        printf(\"allocated %zu MiB for activations\\n\", (num_activations * sizeof(float)) >> 20); // >> 20 is /(1024*1024)\n        // also create memory for caching inputs and targets\n        cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));\n        cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));\n        cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));\n    } else {\n        // validate B,T is consistent with how we've allocated the memory before\n        // in principle we could get more clever here in the future, for now this is safest\n        if (B != model->batch_size || T != model->seq_len) {\n            printf(\"Model: B=%d T=%d, Desired: B=%d T=%d\\n\", model->batch_size, model->seq_len, B, T);\n            exit(EXIT_FAILURE);\n        }\n    }\n\n    // copy inputs/targets to the model\n    cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice));\n    if (targets != NULL) {\n        cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), 
cudaMemcpyHostToDevice));\n    }\n\n    // forward pass\n    ParameterTensors params = model->params; // for brevity\n    ActivationTensors acts = model->acts;\n    float* residual;\n    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]\n\n    for (int l = 0; l < L; l++) {\n\n        residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;\n\n        // get the pointers of the weights for this layer\n        float* l_ln1w = params.ln1w + l * C;\n        float* l_ln1b = params.ln1b + l * C;\n        float* l_qkvw = params.qkvw + l * 3*C * C;\n        float* l_qkvb = params.qkvb + l * 3*C;\n        float* l_attprojw = params.attprojw + l * C * C;\n        float* l_attprojb = params.attprojb + l * C;\n        float* l_ln2w = params.ln2w + l * C;\n        float* l_ln2b = params.ln2b + l * C;\n        float* l_fcw = params.fcw + l * 4*C * C;\n        float* l_fcb = params.fcb + l * 4*C;\n        float* l_fcprojw = params.fcprojw + l * C * 4*C;\n        float* l_fcprojb = params.fcprojb + l * C;\n\n        // get the pointers of the activations for this layer\n        float* l_ln1 = acts.ln1 + l * B * T * C;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        float* l_qkvr = acts.qkvr + l * B * T * 3*C;\n        float* l_atty = acts.atty + l * B * T * C;\n        float* l_att = acts.att + l * B * NH * T * T;\n        float* l_attproj = acts.attproj + l * B * T * C;\n        float* l_residual2 = acts.residual2 + l * B * T * C;\n        float* l_ln2 = acts.ln2 + l * B * T * C;\n        float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        float* l_fch = acts.fch + l * B * T * 4*C;\n        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;\n        float* l_fcproj = acts.fcproj + l * B * T * C;\n        float* l_residual3 = acts.residual3 + l * B * T * C;\n        // 
these are only needed as scratchpads for the forward pass, but\n        // need not be stored for backward\n        float* scratch = acts.output;\n\n        // now do the forward pass\n        layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);\n        matmul_forward(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);\n        attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH);\n        matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);\n        residual_forward(l_residual2, residual, l_attproj, B*T*C);\n        layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);\n        matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);\n        gelu_forward(l_fch_gelu, l_fch, B*T*4*C);\n        matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);\n        residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);\n    }\n\n    residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3\n    layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);\n    matmul_forward(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp);\n\n    // also forward the cross-entropy loss function if we have the targets\n    if (targets != NULL) {\n        // fused classifier: does the forward pass and first part of the backward pass\n        // we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. 
uniform loss\n        fused_classifier3(acts.output, acts.losses, NULL, model->targets, B, T, V, Vp);\n        // for convenience also evaluate the mean loss (TODO re-think this compute+sync point)\n        // move the (B,T) losses to CPU\n        cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost));\n        float mean_loss = 0.0f;\n        for (int i=0; i<B*T; i++) { mean_loss += model->cpu_losses[i]; }\n        mean_loss /= B*T;\n        model->mean_loss = mean_loss;\n\n    } else {\n        // if we don't have targets, we don't have loss\n        model->mean_loss = -1.0f;\n    }\n}\n\nvoid gpt2_zero_grad(GPT2 *model) {\n    if (model->grads_acts_memory != NULL) { cudaCheck(cudaMemset(model->grads_acts_memory, 0, model->num_grad_acts * sizeof(float))); }\n    if (model->grads_memory != NULL) { cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(float))); }\n}\n\nvoid gpt2_backward(GPT2 *model) {\n\n    // double check we forwarded previously, with targets\n    if (model->mean_loss == -1.0f) {\n        printf(\"Error: must forward with targets before backward\\n\");\n        exit(EXIT_FAILURE);\n    }\n\n    // lazily allocate the memory for gradients of the weights and activations, if needed\n    if (model->grads_memory == NULL) {\n        // allocate buffers for weight gradients\n        model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes, 1);\n        printf(\"allocated %zu MiB for parameter gradients\\n\", (model->num_parameters * sizeof(float)) >> 20);\n        // we're going to be clever for the activations backward pass. 
we don't need to exactly\n        // mirror the forward pass acrtivations and we will save memory.\n        size_t bw_act_sizes[NUM_ACTIVATION_TENSORS];\n        GPT2Config cfg = model->config;\n        cfg.num_layers = 1; // copy the configuration but override number of layers to 1\n        fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, cfg);\n        // count up and allocate the space\n        model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);\n        model->num_grad_acts = 0;\n        for (int i = 0; i < NUM_BACKWARD_TENSORS; i++) {\n            model->num_grad_acts += bw_act_sizes[i];\n        }\n        printf(\"allocated %zu MiB for activation gradients\\n\", (model->num_grad_acts * sizeof(float)) >> 20);\n        // init gradients of parameters and activations to zero\n        gpt2_zero_grad(model);\n    }\n\n    // convenience shortcuts\n    int B = model->batch_size;\n    int T = model->seq_len;\n    int Vp = model->config.padded_vocab_size;\n    int L = model->config.num_layers;\n    int NH = model->config.num_heads;\n    int C = model->config.channels;\n\n    // backward pass: go in the reverse order of the forward pass, and call backward() functions\n    ParameterTensors params = model->params; // for brevity\n    ParameterTensors grads = model->grads;\n    ActivationTensors acts = model->acts;\n    GradActTensors grads_acts = model->grads_acts;\n\n    // we kick off the chain rule by filling in dlosses with 1.0f/(B*T)\n    // this was done in the fused classifier kernel as last step of forward pass\n    // technically that is a small, inline backward() pass of calculating\n    // total, final loss as the mean over all losses over all (B,T) positions in the batch\n    // next: backward the classifier matmul\n    matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, Vp);\n    // backward the final layernorm\n    float* residual = acts.residual3 + (L-1) 
* B * T * C; // last residual is in residual3\n    float* dresidual = grads_acts.residual3; // the main buffer holding the gradient in the backward pass\n    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);\n\n    // now backward all the layers\n    for (int l = L-1; l >= 0; l--) {\n        residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;\n\n        // get the pointers of the weights for this layer\n        float* l_ln1w = params.ln1w + l * C;\n        float* l_qkvw = params.qkvw + l * 3*C * C;\n        float* l_attprojw = params.attprojw + l * C * C;\n        float* l_ln2w = params.ln2w + l * C;\n        float* l_fcw = params.fcw + l * 4*C * C;\n        float* l_fcprojw = params.fcprojw + l * C * 4*C;\n        // get the pointers of the gradients of the weights for this layer\n        float* dl_ln1w = grads.ln1w + l * C;\n        float* dl_ln1b = grads.ln1b + l * C;\n        float* dl_qkvw = grads.qkvw + l * 3*C * C;\n        float* dl_qkvb = grads.qkvb + l * 3*C;\n        float* dl_attprojw = grads.attprojw + l * C * C;\n        float* dl_attprojb = grads.attprojb + l * C;\n        float* dl_ln2w = grads.ln2w + l * C;\n        float* dl_ln2b = grads.ln2b + l * C;\n        float* dl_fcw = grads.fcw + l * 4*C * C;\n        float* dl_fcb = grads.fcb + l * 4*C;\n        float* dl_fcprojw = grads.fcprojw + l * C * 4*C;\n        float* dl_fcprojb = grads.fcprojb + l * C;\n        // get the pointers of the activations for this layer\n        float* l_ln1 = acts.ln1 + l * B * T * C;\n        float* l_ln1_mean = acts.ln1_mean + l * B * T;\n        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;\n        float* l_qkvr = acts.qkvr + l * B * T * 3*C;\n        float* l_atty = acts.atty + l * B * T * C;\n        float* l_att = acts.att + l * B * NH * T * T;\n        float* l_residual2 = acts.residual2 + l * B * T * C;\n        float* l_ln2 = acts.ln2 + l * B * T * C;\n    
    float* l_ln2_mean = acts.ln2_mean + l * B * T;\n        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;\n        float* l_fch = acts.fch + l * B * T * 4*C;\n        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;\n        // get the pointers of the gradients of the activations for this layer\n        // notice that there is no l *, because we just have a single copy, and keep\n        // re-using this memory in every Transformer block as we calculate backward pass\n\n        // we need a B x T x C buffer; thankfully, the forward activation for lnf isn't needed anymore,\n        // so we can co-opt it here.\n        float* dl_btc = acts.lnf;\n        float* dl_bt4c = grads_acts.bt4c;\n        float* dl_preatt = grads_acts.preatt;\n\n        // re-use scratch buffer of the forward pass\n        float* scratch = acts.output;\n\n        // backprop this layer\n        matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, B, T, 4*C, C);\n        gelu_backward(dl_bt4c, l_fch, dl_bt4c, B*T*4*C);\n        matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, B, T, C, 4 * C);\n        // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above\n        layernorm_backward(dresidual, dl_ln2w, dl_ln2b, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C);\n        matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, B, T, C, C);\n        // we more B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory\n        float* buffer_a = l_atty;\n        float* buffer_b = l_fch;        // this is B x T x 4C, so even larger than what we need\n\n        attention_backward(dl_bt4c, buffer_b, dl_preatt, scratch, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);\n        matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, B, T, C, 3 * C);\n        // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above\n        layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);\n    }\n    encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C);\n}\n\nvoid gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {\n    // reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html\n\n    // lazily allocate the memory for m_memory and v_memory\n    if (model->m_memory == NULL) {\n        cudaCheck(cudaMalloc((void**)&model->m_memory, model->num_parameters * sizeof(float)));\n        cudaCheck(cudaMalloc((void**)&model->v_memory, model->num_parameters * sizeof(float)));\n        cudaCheck(cudaMemset(model->m_memory, 0, model->num_parameters * sizeof(float)));\n        cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float)));\n        printf(\"allocated %zu MiB for AdamW optimizer state m\\n\", (model->num_parameters * sizeof(float)) >> 20);\n        printf(\"allocated %zu MiB for AdamW optimizer state v\\n\", (model->num_parameters * sizeof(float)) >> 20);\n    }\n\n    int block_size = 512;\n    int num_blocks = CEIL_DIV(model->num_parameters, block_size);\n    float beta1_correction = 1.0f - powf(beta1, t);\n    float beta2_correction = 1.0f - powf(beta2, t);\n    adamw_kernel2<<<num_blocks, block_size>>>(model->params_memory, model->grads_memory, model->m_memory, model->v_memory,\n        
                                      model->num_parameters,\n                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);\n    cudaCheck(cudaGetLastError());\n}\n\nvoid gpt2_free(GPT2 *model) {\n    cudaCheck(cudaFree(model->params_memory));\n    cudaCheck(cudaFree(model->grads_memory));\n    cudaCheck(cudaFree(model->m_memory));\n    cudaCheck(cudaFree(model->v_memory));\n    cudaCheck(cudaFree(model->acts_memory));\n    cudaCheck(cudaFree(model->grads_acts_memory));\n    cudaCheck(cudaFree(model->inputs));\n    cudaCheck(cudaFree(model->targets));\n    cudaFreeHost(model->cpu_losses);\n}\n\n#ifndef TESTING\n// if we are TESTING (see test_gpt2.cu), we'll skip the int main below\n// ----------------------------------------------------------------------------\n// sampler: takes probabilities and samples integers from them\n\n#define GPT2_EOT 50256\n\nunsigned int random_u32(unsigned long long *state) {\n    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A\n    *state ^= *state >> 12;\n    *state ^= *state << 25;\n    *state ^= *state >> 27;\n    return (*state * 0x2545F4914F6CDD1Dull) >> 32;\n}\nfloat random_f32(unsigned long long *state) { // random float32 in [0,1)\n    return (random_u32(state) >> 8) / 16777216.0f;\n}\n\nint sample_softmax(const float* logits, int n, float coin) {\n    // sample index from logits (converted to probabilities using softmax)\n    // coin is a random number in [0, 1), usually from random_f32()\n    double norm = 0;\n    for (int i = 0; i < n; i++) {\n        norm += expf(logits[i]);\n    }\n    // instead of dividing all exp(logits), we can just multiply coin.\n    coin *= norm;\n    float cdf = 0.0f;\n    for (int i = 0; i < n; i++) {\n        cdf += expf(logits[i]);\n        if (coin < cdf) {\n            return i;\n        }\n    }\n    return n - 1; // in case of rounding errors\n}\n\n// 
----------------------------------------------------------------------------\n// Logger lite, will probably grow/change some over time\n\ntypedef struct {\n    FILE *logfile;\n    int flush_every; // every how many steps to flush the log\n} Logger;\n\nvoid logger_init(Logger *logger, const char *filename) {\n    logger->flush_every = 20;\n    logger->logfile = NULL;\n    if (filename != NULL) { logger->logfile = fopenCheck(filename, \"w\"); }\n}\n\nvoid logger_log_val(Logger *logger, int step, float val_loss) {\n    if (logger->logfile != NULL) {\n        fprintf(logger->logfile, \"s:%d tel:%.4f\\n\", step, val_loss);\n    }\n}\n\nvoid logger_log_train(Logger *logger, int step, float train_loss) {\n    if (logger->logfile != NULL) {\n        fprintf(logger->logfile, \"s:%d trl:%.4f\\n\", step, train_loss);\n        if (step % 10 == 0) { fflush(logger->logfile); }\n    }\n}\n\nvoid logger_free(Logger *logger) {\n    if (logger->logfile != NULL) { fclose(logger->logfile); }\n}\n\n// ----------------------------------------------------------------------------\n// CLI, poor man's argparse\n\nvoid error_usage() {\n    fprintf(stderr, \"Usage:   ./train_gpt2fp32cu [options]\\n\");\n    fprintf(stderr, \"Options:\\n\");\n    fprintf(stderr, \"  -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\\n\");\n    fprintf(stderr, \"  -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\\n\");\n    fprintf(stderr, \"  -o <string> output log file (default = NULL)\\n\");\n    fprintf(stderr, \"  -b <int>    batch size B (default = 4)\\n\");\n    fprintf(stderr, \"  -t <int>    sequence length T (default = 1024)\\n\");\n    fprintf(stderr, \"  -l <float>  learning rate (default = 3e-4f)\\n\");\n    fprintf(stderr, \"  -v <int>    val_loss_every, how often we evaluate val loss (default = 20)\\n\");\n    fprintf(stderr, \"  -m <int>    val_max_steps, up to how many val batches to 
estimate val loss? (default = 20)\\n\");\n    fprintf(stderr, \"  -s <int>    sample_every, how often we inference the model (default = 20)\\n\");\n    fprintf(stderr, \"  -g <int>    genT, how many steps of inference we do (default = 64)\\n\");\n    exit(EXIT_FAILURE);\n}\n\n// ----------------------------------------------------------------------------\n// main training loop\nint main(int argc, char *argv[]) {\n\n    // read in the (optional) command line arguments\n    const char* train_data_pattern = \"dev/data/tinyshakespeare/tiny_shakespeare_train.bin\";\n    const char* val_data_pattern = \"dev/data/tinyshakespeare/tiny_shakespeare_val.bin\";\n    const char* output_log_file = NULL;\n    int B = 4; // batch size\n    int T = 1024; // sequence length max\n    float learning_rate = 3e-4f;\n    int val_loss_every = 20; // every how many steps do we eval validation loss?\n    int val_max_steps = 20; // how many batches max do we eval for validation loss?\n    int sample_every = 20; // every how many steps to do inference?\n    int genT = 64; // number of steps of inference we will do\n    for (int i = 1; i < argc; i+=2) {\n        if (i + 1 >= argc) { error_usage(); } // must have arg after flag\n        if (argv[i][0] != '-') { error_usage(); } // must start with dash\n        if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)\n        // read in the args\n        if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }\n        else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }\n        else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }\n        else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); }\n        else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }\n        else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); }\n        else if (argv[i][1] 
== 's') { sample_every = atoi(argv[i+1]); }\n        else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); }\n        else { error_usage(); }\n    }\n    printf(\"+-----------------------+----------------------------------------------------+\\n\");\n    printf(\"| Parameter             | Value                                              |\\n\");\n    printf(\"+-----------------------+----------------------------------------------------+\\n\");\n    printf(\"| train data pattern    | %-50s |\\n\", train_data_pattern);\n    printf(\"| val data pattern      | %-50s |\\n\", val_data_pattern);\n    printf(\"| output log file       | %-50s |\\n\", output_log_file == NULL ? \"NULL\" : output_log_file);\n    printf(\"| batch size B          | %-50d |\\n\", B);\n    printf(\"| sequence length T     | %-50d |\\n\", T);\n    printf(\"| learning rate         | %-50f |\\n\", learning_rate);\n    printf(\"| val_loss_every        | %-50d |\\n\", val_loss_every);\n    printf(\"| val_max_steps         | %-50d |\\n\", val_max_steps);\n    printf(\"| sample_every          | %-50d |\\n\", sample_every);\n    printf(\"| genT                  | %-50d |\\n\", genT);\n    printf(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // set up the device\n    int deviceIdx = 0;\n    cudaCheck(cudaSetDevice(deviceIdx));\n    cudaDeviceProp deviceProp;\n    cudaGetDeviceProperties(&deviceProp, deviceIdx);\n    // setup cuBLAS and cuBLASLt\n    cublasCheck(cublasCreate(&cublas_handle));\n    // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')\n    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;\n    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;\n    cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;\n    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));\n    printf(\"| device                | %-50s |\\n\", deviceProp.name);\n    printf(\"| TF32                  | %-50s |\\n\", enable_tf32 ? \"enabled\" : \"disabled\");\n    printf(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // build the GPT-2 model from a checkpoint\n    GPT2 model;\n    gpt2_build_from_checkpoint(&model, \"gpt2_124M.bin\");\n    printf(\"| max_sequence_length T | %-50d |\\n\", model.config.max_seq_len);\n    printf(\"| vocab_size V          | %-50d |\\n\", model.config.vocab_size);\n    printf(\"| padded_vocab_size Vp  | %-50d |\\n\", model.config.padded_vocab_size);\n    printf(\"| num_layers L          | %-50d |\\n\", model.config.num_layers);\n    printf(\"| num_heads NH          | %-50d |\\n\", model.config.num_heads);\n    printf(\"| channels C            | %-50d |\\n\", model.config.channels);\n    printf(\"| num_parameters        | %-50zu |\\n\", model.num_parameters);\n    printf(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // build DataLoaders for both train and val\n    DataLoader train_loader, val_loader;\n    dataloader_init(&train_loader, train_data_pattern, B, T, 0, 1, 1);\n    dataloader_init(&val_loader, val_data_pattern, B, T, 0, 1, 0);\n    int train_num_batches = train_loader.num_tokens / (B*T); // let's do 1 epoch by default for now\n    int val_num_batches = val_loader.num_tokens / (B*T);\n    if (val_num_batches > val_max_steps) { val_num_batches = val_max_steps; }\n    printf(\"| train_num_batches     | %-50d |\\n\", train_num_batches);\n    printf(\"| val_num_batches       | %-50d |\\n\", val_num_batches);\n    printf(\"+-----------------------+----------------------------------------------------+\\n\");\n\n    // print model parameter allocations from gpt2_build_from_checkpoint down here to not mess up 
our table above\n    printf(\"allocated %d MiB for model parameters\\n\", (int)round(model.num_parameters * sizeof(float) / (1024 * 1024)));\n\n    // set up the Logger\n    Logger logger;\n    logger_init(&logger, output_log_file);\n\n    // build the Tokenizer\n    Tokenizer tokenizer;\n    tokenizer_init(&tokenizer, \"gpt2_tokenizer.bin\");\n\n    // some memory for generating samples from the model\n    unsigned long long rng_state = 1337;\n    int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));\n    float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float));\n\n    // train\n    struct timespec start, end;\n    double total_sum_iteration_time_s = 0.0;\n    for (int step = 0; step <= train_num_batches; step++) {\n        int last_step = step == train_num_batches;\n\n        // once in a while estimate the validation loss\n        if (step % val_loss_every == 0 || last_step) {\n            float val_loss = 0.0f;\n            dataloader_reset(&val_loader);\n            for (int i = 0; i < val_num_batches; i++) {\n                dataloader_next_batch(&val_loader);\n                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);\n                val_loss += model.mean_loss;\n            }\n            val_loss /= val_num_batches;\n            printf(\"val loss %f\\n\", val_loss);\n            logger_log_val(&logger, step, val_loss);\n        }\n\n        // once in a while do model inference to print generated text\n        if (step > 0 && step % sample_every == 0 || last_step) {\n            // fill up gen_tokens with the GPT2_EOT, which kicks off the generation\n            for(int i = 0; i < B * T; ++i) {\n                gen_tokens[i] = GPT2_EOT;\n            }\n            // now sample from the model autoregressively\n            printf(\"generating:\\n---\\n\");\n            for (int t = 1; t < genT; t++) {\n                // note that inference is very wasteful here because for each token\n                // 
we re-calculate the forward pass for all of (B,T) positions from scratch\n                // but the inference here is just for sanity checking anyway\n                // and we can maybe optimize a bit more later, with careful tests\n                gpt2_forward(&model, gen_tokens, NULL, B, T);\n                // furthermore, below we're only using b=0 (i.e. the first row) of all B rows\n                // we're in principle running B \"inference streams\" in parallel here\n                // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)\n                // get the V-dimensional vector probs[0, t-1, :]\n                float* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;\n                // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)\n                cudaCheck(cudaMemcpy(cpu_logits, logits, model.config.vocab_size * sizeof(float), cudaMemcpyDeviceToHost));\n                float coin = random_f32(&rng_state);\n                int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin);\n                gen_tokens[t] = next_token;\n                // print the generated token, either using the Tokenizer or a fallback\n                if (tokenizer.init_ok) {\n                    const char* token_str = tokenizer_decode(&tokenizer, next_token);\n                    safe_printf(token_str);\n                } else {\n                    // fall back to printing the token id\n                    printf(\"%d \", next_token);\n                }\n                fflush(stdout);\n            }\n            printf(\"\\n---\\n\");\n        }\n\n        // bit confusing: we want to make sure to eval and sample on 0th iteration\n        // but also after the very last iteration. 
so we loop for step <= train_num_batches\n        // instead of just < train_num_batches (one extra due to <=), only to do\n        // the validation/sampling one last time, and then we break right here as we're done.\n        if (last_step) { break; }\n\n        // do a training step\n        clock_gettime(CLOCK_MONOTONIC, &start);\n        dataloader_next_batch(&train_loader);\n        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);\n        gpt2_zero_grad(&model);\n        gpt2_backward(&model);\n        gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);\n        cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings\n        clock_gettime(CLOCK_MONOTONIC, &end);\n        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;\n        total_sum_iteration_time_s += time_elapsed_s;\n        int tokens_per_second = (B * T) / time_elapsed_s;\n        printf(\"step %4d/%d: train loss %f (%f ms, %d tok/s)\\n\", step + 1, train_num_batches, model.mean_loss, time_elapsed_s * 1000, tokens_per_second);\n        logger_log_train(&logger, step, model.mean_loss);\n    }\n    // add a total average, for optimizations that are only mild improvements\n    printf(\"total average iteration time: %f ms\\n\", total_sum_iteration_time_s / train_num_batches * 1000);\n\n    // free\n    dataloader_free(&train_loader);\n    dataloader_free(&val_loader);\n    tokenizer_free(&tokenizer);\n    gpt2_free(&model);\n    free(cpu_logits);\n    free(gen_tokens);\n    cublasCheck(cublasDestroy(cublas_handle));\n    logger_free(&logger);\n\n    return 0;\n}\n#endif"
  },
  {
    "path": "train_llama3.py",
    "content": "\"\"\"\nReference code for LLaMA-3.1 training and inference.\nWill save the model weights into files, to be read from C as initialization.\n\nThis code differs from GPT-2 very slightly, there are three main differences:\n1) RoPE: LLaMA uses a different positional encoding scheme called Relative Positional Encoding (RoPE).\n2) GQA: Grouped Query Attention (GQA) is used to reduce the number of attention heads.\n3) SwiGLU: Swish-Gated Linear Unit (SwiGLU) is used as the activation function in the MLP.\n\nReferences:\n# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py\n# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py\n# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py\n\nExample launches to only benchmark the speed of bfloat16 compiled GPU training:\nTODO: add the actual commands\n\"\"\"\n\nimport argparse\nimport os\nimport math\nimport glob\nimport inspect\nfrom contextlib import nullcontext\nfrom dataclasses import dataclass\nfrom pathlib import Path\nimport time\nfrom typing import (\n    AbstractSet,\n    Collection,\n    Dict,\n    Iterator,\n    List,\n    Literal,\n    Optional,\n    Sequence,\n    Tuple,\n    Union,\n    cast,\n)\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport torch._inductor.config as config\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.distributed import init_process_group, destroy_process_group\nfrom torch.distributed.optim import ZeroRedundancyOptimizer\nimport torch.distributed as dist\n\nimport tiktoken\nfrom tiktoken.load import load_tiktoken_bpe\n\n# -----------------------------------------------------------------------------\n# PyTorch nn.Module definitions for the LLaMA 3.x model\n\n# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors\ndef repeat_kv(x: torch.Tensor, n_rep: 
int) -> torch.Tensor:\n    \"\"\"torch.repeat_interleave(x, dim=2, repeats=n_rep)\"\"\"\n    bs, slen, n_kv_heads, head_dim = x.shape\n    if n_rep == 1:\n        return x\n    return (\n        x[:, :, :, None, :]\n        .expand(bs, slen, n_kv_heads, n_rep, head_dim)\n        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)\n    )\n\n# -----------------------------------------------------------------------------\n# RoPE related\n\ndef reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n    ndim = x.ndim\n    assert 0 <= 1 < ndim\n    assert freqs_cis.shape == (x.shape[1], x.shape[-1])\n    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]\n    return freqs_cis.view(*shape)\n\ndef apply_scaling(freqs: torch.Tensor):\n    # Values obtained from grid search\n    scale_factor = 8\n    low_freq_factor = 1\n    high_freq_factor = 4\n    old_context_len = 8192  # original llama3 length\n\n    low_freq_wavelen = old_context_len / low_freq_factor\n    high_freq_wavelen = old_context_len / high_freq_factor\n    new_freqs = []\n    for freq in freqs:\n        wavelen = 2 * math.pi / freq\n        if wavelen < high_freq_wavelen:\n            new_freqs.append(freq)\n        elif wavelen > low_freq_wavelen:\n            new_freqs.append(freq / scale_factor)\n        else:\n            assert low_freq_wavelen != high_freq_wavelen\n            smooth = (old_context_len / wavelen - low_freq_factor) / (\n                high_freq_factor - low_freq_factor\n            )\n            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)\n    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)\n\ndef apply_rotary_emb(\n    xq: torch.Tensor,\n    xk: torch.Tensor,\n    freqs_cis: torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))\n    freqs_cis = 
reshape_for_broadcast(freqs_cis, xq_)\n    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n    return xq_out.type_as(xq), xk_out.type_as(xk)\n\ndef precompute_freqs_cis(\n    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False\n):\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n    t = torch.arange(end, device=freqs.device, dtype=torch.float32)\n    if use_scaled:\n        freqs = apply_scaling(freqs)\n    freqs = torch.outer(t, freqs)\n    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64\n    return freqs_cis\n\n# -----------------------------------------------------------------------------\n# LLaMA building blocks\n\n# LLaMA reference code explicitly implemented RMSNorm so we copy pasted it\n# (https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py)\n# we could also use nn.RMSNorm, it has slightly different numeric properties, but equivalent\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float()).type_as(x)\n        return output * self.weight\n\nclass CausalSelfAttention(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        assert config.n_embd % config.n_head == 0\n\n        self.n_head = config.n_head\n        self.n_kv_head = config.n_kv_head\n        self.n_rep = self.n_head // self.n_kv_head\n        self.hd = config.n_embd // config.n_head\n        self.use_kv = config.use_kv\n        self.flash = config.flash\n\n        self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False)  # key, query, value projections\n     
   self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)  # output projection\n\n        # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed\n        if self.use_kv:\n            self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))\n            self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))\n\n    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):\n        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)\n        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n        qkv = self.c_attn(x)\n        q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1)\n        q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v))  # (B, T, NH, HD)\n\n        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)  # rotate QK (rope)  <-- 1. difference compared to GPT-2\n\n        if self.use_kv and not self.training and start_pos >= 0:  # use kv-caching during inference\n            self.cache_k[:B, start_pos : start_pos + T] = k\n            self.cache_v[:B, start_pos : start_pos + T] = v\n            k = self.cache_k[:B, : start_pos + T]\n            v = self.cache_v[:B, : start_pos + T]\n\n        k = repeat_kv(k, self.n_rep)  # GQA <-- 2. 
difference compared to GPT-2\n        v = repeat_kv(v, self.n_rep)\n\n        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))  # (B, NH, T, HD)\n\n        if self.flash:\n            # flashattention\n            # if T == 1 no need to mask, otherwise the function complains\n            # scaled_dot_product_attention expects a mask where value of True indicates that the element should take part in attention\n            # our mask is the opposite, so we need to invert it\n            y = F.scaled_dot_product_attention(q, k, v, mask == 0 if T > 1 else None)\n        else:\n            # manual implementation of attention\n            # this materializes the large (T,T) matrix for all the queries and keys\n            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd))\n            if mask is not None:\n                scores.masked_fill_(mask, torch.finfo(scores.dtype).min)\n            att = F.softmax(scores.float(), dim=-1).type_as(q)\n            y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD)\n        y = y.transpose(1, 2).contiguous().view(B, T, C)\n        y = self.c_proj(y)\n        return y\n\nclass MLP(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        hidden_dim = 4 * config.n_embd\n        hidden_dim = int(2 * hidden_dim / 3)\n        # custom dim factor multiplier\n        if config.ffn_dim_multiplier is not None:\n            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)\n        hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)\n        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)\n        self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False)\n        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)\n\n    def forward(self, x):\n        # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x))  <-- 3. 
difference compared to GPT-2\n        x1 = self.c_fc(x)\n        x2 = self.c_fc2(x)\n        x2 = F.silu(x2)\n        x = x1 * x2\n        x = self.c_proj(x)\n        return x\n\nclass Block(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.ln_1 = RMSNorm(config.n_embd, config.norm_eps)\n        self.attn = CausalSelfAttention(config)\n        self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)\n        self.mlp = MLP(config)\n\n    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):\n        x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask)\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n# -----------------------------------------------------------------------------\n# The main LLaMA 3.1 model\n\n@dataclass\nclass LlamaConfig:\n    version: str = \"3.1\"\n    block_size: int = 8192\n    vocab_size: int = 128256\n    n_layer: int = 32\n    n_head: int = 32\n    n_kv_head: int = 8\n    n_embd: int = 4096\n    ffn_dim_multiplier: float = 1.3\n    multiple_of: int = 1024\n    norm_eps: float = 1e-5\n    rope_theta: float = 500000.0\n    use_scaled_rope: bool = True\n    max_gen_batch_size: int = 4\n    use_kv: bool = True\n    flash: bool = False  # use flashattention?\n\n    def __init__(self, **kwargs):\n        for k, v in kwargs.items():\n            if hasattr(self, k):\n                setattr(self, k, v)\n        assert self.n_kv_head <= self.n_head\n        assert self.n_head % self.n_kv_head == 0\n        assert self.n_embd % self.n_head == 0\n\nclass LLaMA(nn.Module):\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n\n        self.transformer = nn.ModuleDict(dict(\n            wte = nn.Embedding(config.vocab_size, config.n_embd),\n            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),\n            ln_f = RMSNorm(config.n_embd, config.norm_eps),\n        ))\n        self.lm_head = nn.Linear(config.n_embd, 
config.vocab_size, bias=False)\n\n        # init all weights, use a torch rng object to be very careful\n        self.init_rng = torch.Generator()\n        self.init_rng.manual_seed(42)\n\n        self.freqs_cis = precompute_freqs_cis(\n            config.n_embd // config.n_head,\n            config.block_size * 2,\n            config.rope_theta,\n            config.use_scaled_rope,\n        )\n\n    def forward(self, idx, targets=None, return_logits=True, start_pos=0):\n        _, t = idx.size()\n        assert t <= self.config.block_size, f\"Cannot forward sequence of length {t}, block size is only {self.config.block_size}\"\n\n        # forward the LLaMA model itself\n        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)\n        freqs_cis = self.freqs_cis[start_pos:start_pos+t]\n\n        mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1)\n\n        for i, block in enumerate(self.transformer.h):\n            x = block(x, freqs_cis, start_pos, mask)\n        x = self.transformer.ln_f(x)\n\n        if targets is not None:\n            # if we are given some desired targets also calculate the loss\n            logits = self.lm_head(x).float()\n            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)\n        else:\n            # inference-time mini-optimization: only forward the lm_head on the very last position\n            logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim\n            loss = None\n\n        # there are performance reasons why not returning logits is prudent, if not needed\n        if not return_logits:\n            logits = None\n\n        return logits, loss\n\n    @staticmethod\n    def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig):\n        # Modify key names from Meta's LLaMA to our LLaMA\n        # our key names are derived from GPT-2's key names\n        
checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight')\n\n        for i in range(config.n_layer):\n            for name in ['attention_norm', 'ffn_norm']:\n                old_key = f'layers.{i}.{name}.weight'  # e.g. layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight\n                new_key = f'transformer.h.{i}.ln_{1 if name == \"attention_norm\" else 2}.weight'\n                checkpoint[new_key] = checkpoint.pop(old_key)\n\n        for i in range(config.n_layer):\n            for name in ['attention.wq', 'attention.wk', 'attention.wv']:\n                old_key = f'layers.{i}.{name}.weight'\n                new_key = f'transformer.h.{i}.attn.c_attn.weight'\n                if name == 'attention.wq':\n                    checkpoint[new_key] = checkpoint.pop(old_key)\n                else:  # merge 3 weights into transformer.h.x.attn.c_attn.weight\n                    checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0)\n            old_key = f'layers.{i}.attention.wo.weight'\n            new_key = f'transformer.h.{i}.attn.c_proj.weight'\n            checkpoint[new_key] = checkpoint.pop(old_key)\n\n        ffn_map = {'w1': 'c_fc2', 'w2': 'c_proj', 'w3': 'c_fc'}\n        for i in range(config.n_layer):\n            for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']:\n                old_key = f'layers.{i}.{name}.weight'\n                new_key = f'transformer.h.{i}.mlp.{ffn_map[name.split(\".\")[-1]]}.weight'\n                checkpoint[new_key] = checkpoint.pop(old_key)\n\n        checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight')\n        checkpoint['lm_head.weight'] = checkpoint.pop('output.weight')\n\n        return checkpoint\n\n    @staticmethod\n    def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig):\n        # Modify key names from HuggingFace's LLaMA to our LLaMA\n        # our key names are derived from GPT-2's key names\n        
checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight')\n\n        # We need to unpermute Q and K because HF script permuted the original Meta-LLaMA weights\n        # see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py\n        def unpermute(w, n_heads, dim1, dim2):\n            return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)\n\n        for i in range(config.n_layer):\n            for name in ['input_layernorm', 'post_attention_layernorm']:\n                old_key = f'model.layers.{i}.{name}.weight'  # e.g. model.layers.x.input_layernorm.weight -> transformer.h.x.ln_1.weight\n                new_key = f'transformer.h.{i}.ln_{1 if name == \"input_layernorm\" else 2}.weight'\n                checkpoint[new_key] = checkpoint.pop(old_key)\n\n        for i in range(config.n_layer):\n            for name in ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj']:\n                old_key = f'model.layers.{i}.{name}.weight'\n                new_key = f'transformer.h.{i}.attn.c_attn.weight'\n                if name == 'self_attn.q_proj':\n                    checkpoint[new_key] = unpermute(checkpoint.pop(old_key), config.n_head, config.n_embd, config.n_embd)\n                else:  # merge 3 weights into transformer.h.x.attn.c_attn.weight\n                    tensor = checkpoint.pop(old_key)\n                    if name == 'self_attn.k_proj':\n                        tensor = unpermute(tensor, config.n_kv_head, config.n_kv_head * (config.n_embd // config.n_head), config.n_embd)\n                    checkpoint[new_key] = torch.cat((checkpoint[new_key], tensor), dim=0)\n            old_key = f'model.layers.{i}.self_attn.o_proj.weight'\n            new_key = f'transformer.h.{i}.attn.c_proj.weight'\n            checkpoint[new_key] = checkpoint.pop(old_key)\n\n        ffn_map = {'gate_proj': 'c_fc2', 'down_proj': 'c_proj', 'up_proj': 'c_fc'}\n 
       for i in range(config.n_layer):\n            for name in ['gate_proj', 'down_proj', 'up_proj']:\n                old_key = f'model.layers.{i}.mlp.{name}.weight'\n                new_key = f'transformer.h.{i}.mlp.{ffn_map[name]}.weight'\n                checkpoint[new_key] = checkpoint.pop(old_key)\n\n        checkpoint['transformer.ln_f.weight'] = checkpoint.pop('model.norm.weight')\n\n        return checkpoint\n\n    @classmethod\n    def from_pretrained_llama3_hf(cls, model_id):\n        \"\"\"Loads pretrained LLaMA model weights from HuggingFace\"\"\"\n        from transformers import AutoModelForCausalLM, AutoTokenizer\n        assert model_id == \"meta-llama/Meta-Llama-3.1-8B\", \"Only the 8B-base model is supported for now\"\n        model_args = LlamaConfig()\n\n        model = AutoModelForCausalLM.from_pretrained(model_id)\n        checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)\n\n        original_default_type = torch.get_default_dtype()  # save the default type\n        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)  # much faster loading\n        model = LLaMA(model_args)\n        model.load_state_dict(checkpoint, strict=False)\n        torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device=\"cpu\").type())  # restore default type\n\n        tokenizer = AutoTokenizer.from_pretrained(model_id)\n        tokenizer.pad_id = 128004  # this is the pad token id for LLaMA 3.1 base, we need to set this explicitly as our generate func expects it\n        tokenizer.stop_tokens = [tokenizer.eos_token_id]\n        model.tokenizer = tokenizer\n        return model\n\n    @classmethod\n    def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path):\n        \"\"\"Loads pretrained LLaMA model weights from a checkpoint directory\"\"\"\n        model_args = LlamaConfig()\n\n        ckpt_path = sorted(Path(ckpt_dir).glob(\"*.pth\"))[0]\n        checkpoint = torch.load(ckpt_path, 
map_location=\"cpu\", weights_only=True)\n        checkpoint = LLaMA.adapt_llama_state_dict_keys(checkpoint, model_args)\n\n        original_default_type = torch.get_default_dtype()  # save the default type\n        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)  # much faster loading\n        model = LLaMA(model_args)\n        model.load_state_dict(checkpoint, strict=False)\n        torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device=\"cpu\").type())  # restore default type\n\n        tokenizer = Tokenizer(model_path=tokenizer_path)\n        model.tokenizer = tokenizer\n        return model\n\n    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):\n        # start with all of the candidate parameters\n        param_dict = {pn: p for pn, p in self.named_parameters()}\n        # filter out those that do not require grad\n        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}\n        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.\n        # i.e. 
all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.\n        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n        optim_groups = [\n            {'params': decay_params, 'weight_decay': weight_decay},\n            {'params': nodecay_params, 'weight_decay': 0.0}\n        ]\n        num_decay_params = sum(p.numel() for p in decay_params)\n        num_nodecay_params = sum(p.numel() for p in nodecay_params)\n        print0(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n        print0(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n        # Create AdamW optimizer and use the fused version if it is available\n        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n        use_fused = fused_available and device_type == 'cuda'\n        print0(f\"using fused AdamW: {use_fused}\")\n        if zero_stage == 1:\n            print0(\"using ZeroRedundancyOptimizer\")\n            optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,\n                                                lr=learning_rate, betas=betas, fused=use_fused)\n            optimizer.add_param_group(optim_groups[1])\n        else:\n            print0(\"using regular AdamW\")\n            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)\n        return optimizer\n\n    @torch.inference_mode()\n    def generate(\n        self,\n        prompt_tokens: List[List[int]],\n        max_gen_len: int,\n        temperature: float = 0.6,\n        top_p: float = 0.9,\n        echo: bool = False,\n    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:\n        \"\"\"\n        Generate text sequences based on provided prompts using the language generation model.\n\n        Args:\n    
        prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.\n            max_gen_len (int): Maximum length of the generated text sequence.\n            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.\n            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.\n            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.\n\n        Returns:\n            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences.\n\n        Note:\n            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.\n\n        \"\"\"\n        bsz = len(prompt_tokens)\n        assert bsz <= self.config.max_gen_batch_size, f\"Batch size {bsz} exceeds the maximum generation batch size {self.config.max_gen_batch_size}\"\n        device = next(self.parameters()).device\n\n        min_prompt_len = min(len(t) for t in prompt_tokens)\n        max_prompt_len = max(len(t) for t in prompt_tokens)\n        assert max_prompt_len <= self.config.block_size, f\"Prompt length {max_prompt_len} exceeds the maximum block size {self.config.block_size}\"\n        total_len = min(self.config.block_size, max_gen_len + max_prompt_len)\n\n        pad_id = self.tokenizer.pad_id\n        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)\n        for idx, t in enumerate(prompt_tokens):\n            tokens[idx, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)\n\n        prev_pos = 0\n        eos_reached = torch.tensor([False] * bsz, device=device)\n        input_text_mask = tokens != pad_id\n\n        if min_prompt_len == total_len:\n            logits, _ = self.forward(tokens, start_pos=prev_pos)\n\n        stop_tokens = 
torch.tensor(list(self.tokenizer.stop_tokens)).to(device)\n\n        for cur_pos in range(min_prompt_len, total_len):\n            logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos)\n            if temperature > 0:\n                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)\n                next_token = sample_top_p(probs, top_p)\n            else:\n                next_token = torch.argmax(logits[:, -1], dim=-1)\n\n            next_token = next_token.reshape(-1)\n            # only replace token if prompt has already been generated\n            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)\n            tokens[:, cur_pos] = next_token\n            eos_reached |= ~input_text_mask[:, cur_pos] & torch.isin(next_token, stop_tokens)\n            prev_pos = cur_pos\n            if all(eos_reached):\n                break\n\n        out_tokens = []\n        for i, toks in enumerate(tokens.tolist()):\n            # cut to max gen len\n            start = 0 if echo else len(prompt_tokens[i])\n            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]\n            # cut to after eos tok if any\n            for stop_token in self.tokenizer.stop_tokens:\n                try:\n                    eos_idx = toks.index(stop_token)\n                    toks = toks[:eos_idx]\n                except ValueError:\n                    pass\n            out_tokens.append(toks)\n        return out_tokens\n\n# -----------------------------------------------------------------------------\n# sampling utils\n\ndef sample_top_p(probs, p):\n    \"\"\"\n    Perform top-p (nucleus) sampling on a probability distribution.\n\n    Args:\n        probs (torch.Tensor): Probability distribution tensor.\n        p (float): Probability threshold for top-p sampling.\n\n    Returns:\n        torch.Tensor: Sampled token indices.\n\n    Note:\n        Top-p sampling selects the smallest set of tokens whose cumulative 
probability mass\n        exceeds the threshold p. The distribution is renormalized based on the selected tokens.\n    \"\"\"\n    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n    probs_sum = torch.cumsum(probs_sort, dim=-1)\n    mask = probs_sum - probs_sort > p\n    probs_sort[mask] = 0.0\n    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))\n    next_token = torch.multinomial(probs_sort, num_samples=1)\n    next_token = torch.gather(probs_idx, -1, next_token)\n    return next_token\n\n# -----------------------------------------------------------------------------\n# Llama 3.1 Tokenizer\n\n# The tiktoken tokenizer can handle <=400k chars without\n# pyo3_runtime.PanicException.\nTIKTOKEN_MAX_ENCODE_CHARS = 400_000\n\n# https://github.com/openai/tiktoken/issues/195\n# Here we iterate over subsequences and split if we exceed the limit\n# of max consecutive non-whitespace or whitespace characters.\nMAX_NO_WHITESPACES_CHARS = 25_000\n\n\nclass Tokenizer:\n    \"\"\"\n    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.\n    \"\"\"\n\n    special_tokens: Dict[str, int]\n\n    num_reserved_special_tokens = 256\n\n    pat_str = r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\"  # noqa: E501\n\n    def __init__(self, model_path: str):\n        \"\"\"\n        Initializes the Tokenizer with a Tiktoken model.\n\n        Args:\n            model_path (str): The path to the Tiktoken model file.\n        \"\"\"\n        assert os.path.isfile(model_path), model_path\n\n        mergeable_ranks = load_tiktoken_bpe(model_path)\n        num_base_tokens = len(mergeable_ranks)\n        special_tokens = [\n            \"<|begin_of_text|>\",\n            \"<|end_of_text|>\",\n            \"<|reserved_special_token_0|>\",\n            \"<|reserved_special_token_1|>\",\n            \"<|finetune_right_pad_id|>\",\n            \"<|step_id|>\",\n            
\"<|start_header_id|>\",\n            \"<|end_header_id|>\",\n            \"<|eom_id|>\",  # end of message\n            \"<|eot_id|>\",  # end of turn\n            \"<|python_tag|>\",\n        ]\n        reserved_tokens = [\n            f\"<|reserved_special_token_{2 + i}|>\"\n            for i in range(self.num_reserved_special_tokens - len(special_tokens))\n        ]\n        special_tokens = special_tokens + reserved_tokens\n\n        self.special_tokens = {\n            token: num_base_tokens + i for i, token in enumerate(special_tokens)\n        }\n        self.model = tiktoken.Encoding(\n            name=Path(model_path).name,\n            pat_str=self.pat_str,\n            mergeable_ranks=mergeable_ranks,\n            special_tokens=self.special_tokens,\n        )\n\n        self.n_words: int = num_base_tokens + len(special_tokens)\n        # BOS / EOS token IDs\n        self.bos_id: int = self.special_tokens[\"<|begin_of_text|>\"]\n        self.eos_id: int = self.special_tokens[\"<|end_of_text|>\"]\n        self.eot_id: int = self.special_tokens[\"<|eot_id|>\"]\n        self.eom_id: int = self.special_tokens[\"<|eom_id|>\"]\n        self.python_tag_id = self.special_tokens[\"<|python_tag|>\"]\n        self.pad_id: int = self.special_tokens[\"<|finetune_right_pad_id|>\"]\n        # hardcoded stop tokens for the base model\n        self.stop_tokens = [\n            self.special_tokens[\"<|begin_of_text|>\"],\n            self.special_tokens[\"<|end_of_text|>\"],\n        ]\n\n    def encode(\n        self,\n        s: str,\n        *,\n        bos: bool,\n        eos: bool,\n        allowed_special: Optional[Union[Literal[\"all\"], AbstractSet[str]]] = None,\n        disallowed_special: Union[Literal[\"all\"], Collection[str]] = (),\n    ) -> List[int]:\n        \"\"\"\n        Encodes a string into a list of token IDs.\n\n        Args:\n            s (str): The input string to be encoded.\n            bos (bool): Whether to prepend the beginning-of-sequence 
token.\n            eos (bool): Whether to append the end-of-sequence token.\n            allowed_special (\"all\"|set[str]): allowed special tokens in string\n            disallowed_special (\"all\"|set[str]): special tokens that raise an error when in string\n\n        Returns:\n            list[int]: A list of token IDs.\n\n        By default, setting disallowed_special=() encodes a string by ignoring\n        special tokens. Specifically:\n        - Setting `disallowed_special` to () will cause all text corresponding\n          to special tokens to be encoded as natural text (instead of raising\n          an error).\n        - Setting `allowed_special` to \"all\" will cause all text corresponding\n          to special tokens to be encoded as special tokens.\n        \"\"\"\n        if allowed_special is None:\n            allowed_special = set()\n        assert type(s) is str\n\n        substrs = (\n            substr\n            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)\n            for substr in self._split_whitespaces_or_nonwhitespaces(\n                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS\n            )\n        )\n        t: List[int] = []\n        for substr in substrs:\n            t.extend(\n                self.model.encode(\n                    substr,\n                    allowed_special=allowed_special,\n                    disallowed_special=disallowed_special,\n                )\n            )\n        if bos:\n            t.insert(0, self.bos_id)\n        if eos:\n            t.append(self.eos_id)\n        return t\n\n    def decode(self, t: Sequence[int]) -> str:\n        # Typecast is safe here. 
Tiktoken doesn't do anything list-related with the sequence.\n        return self.model.decode(cast(List[int], t))\n\n    @staticmethod\n    def _split_whitespaces_or_nonwhitespaces(\n        s: str, max_consecutive_slice_len: int\n    ) -> Iterator[str]:\n        \"\"\"\n        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`\n        consecutive whitespaces or consecutive non-whitespaces.\n        \"\"\"\n        current_slice_len = 0\n        current_slice_is_space = s[0].isspace() if len(s) > 0 else False\n        slice_start = 0\n\n        for i in range(len(s)):\n            is_now_space = s[i].isspace()\n\n            if current_slice_is_space ^ is_now_space:\n                current_slice_len = 1\n                current_slice_is_space = is_now_space\n            else:\n                current_slice_len += 1\n                if current_slice_len > max_consecutive_slice_len:\n                    yield s[slice_start:i]\n                    slice_start = i\n                    current_slice_len = 1\n        yield s[slice_start:]\n\n# -----------------------------------------------------------------------------\n# Our own simple Distributed Data Loader\n\ndef _peek_data_shard(filename):\n    # only reads the header, returns header data\n    with open(filename, \"rb\") as f:\n        # first read the header, which is 256 int32 integers (4 bytes each)\n        header = np.frombuffer(f.read(256*4), dtype=np.int32)\n    if header[0] != 20240801:\n        print(\"ERROR: magic number mismatch in the data .bin file!\")\n        exit(1)\n    assert header[1] == 7, \"unsupported version\"\n    ntok = header[2] # number of tokens (claimed)\n    return ntok # for now just return the number of tokens\n\ndef _load_data_shard(filename):\n    with open(filename, \"rb\") as f:\n        # first read the header, which is 256 int32 integers (4 bytes each)\n        header = np.frombuffer(f.read(256*4), dtype=np.int32)\n        assert 
header[0] == 20240801, \"magic number mismatch in the data .bin file\"\n        assert header[1] == 7, \"unsupported version\"\n        ntok = header[2] # number of tokens (claimed)\n        # the rest of it are tokens, stored as uint32\n        tokens = np.frombuffer(f.read(), dtype=np.uint32)\n    assert len(tokens) == ntok, \"number of tokens read does not match header?\"\n    return tokens\n\nclass DistributedShardedDataLoader:\n    \"\"\"\n    This DataLoader is both:\n    - distributed (works correctly in case of multiple processes in DDP)\n    - sharded (supports datasets that are broken up into multiple data shards)\n    It is not *permuted*, meaning that it iterates over the data in the order\n    of the dataset on disk, so the user should make sure to shuffle their examples\n    during the creation of their data shards for best performance.\n    \"\"\"\n    def __init__(self, filename_pattern, B, T, process_rank, num_processes):\n        self.process_rank = process_rank\n        self.num_processes = num_processes\n        self.B = B\n        self.T = T\n\n        # glob files that match the pattern\n        self.files = sorted(glob.glob(filename_pattern))\n        assert len(self.files) > 0, f\"did not find any files that match the pattern {filename_pattern}\"\n\n        # load and validate all data shards, count number of tokens in total\n        ntok_total = 0\n        for fname in self.files:\n            shard_ntok = _peek_data_shard(fname)\n            assert shard_ntok >= num_processes * B * T + 1\n            ntok_total += shard_ntok\n        self.ntok_total = ntok_total\n        print0(f\"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files\")\n\n        # kick things off\n        self.current_shard = None\n        self.reset()\n\n    def reset(self):\n        # we're being a bit clever here: if we already had shard 0 loaded,\n        # then don't do the work to reload it, just reset the pointer\n        if 
self.current_shard != 0:\n            self.current_shard = 0\n            self.tokens = _load_data_shard(self.files[self.current_shard])\n        self.current_position = self.process_rank * self.B * self.T\n\n    def advance(self): # advance to next data shard\n        self.current_shard = (self.current_shard + 1) % len(self.files)\n        self.current_position = self.process_rank * self.B * self.T\n        self.tokens = _load_data_shard(self.files[self.current_shard])\n\n    def next_batch(self):\n        B = self.B\n        T = self.T\n        buf = self.tokens[self.current_position : self.current_position+B*T+1]\n        buf = torch.tensor(buf, dtype=torch.long)\n        x = (buf[:-1]).view(B, T) # inputs\n        y = (buf[1:]).view(B, T) # targets\n        # advance the start pointer in current shard\n        self.current_position += B * T * self.num_processes\n        # if loading the next batch would be out of bounds advance the shard\n        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):\n            self.advance()\n        return x, y\n\n# -----------------------------------------------------------------------------\n# Python -> C bridge utilities for saving params/grads/activations to .bin files\n\ndef write_fp32(tensor, file):\n    t = tensor.detach().cpu().to(torch.float32)\n    b = t.numpy().tobytes()\n    file.write(b)\n\ndef write_bf16(tensor, file):\n    t = tensor.detach().cpu().to(torch.bfloat16)\n    # numpy doesn't have bf16 datatype so we have to trick it\n    t = t.view(torch.int16) # trick: reinterpret as int16\n    b = t.numpy().tobytes()\n    file.write(b)\n\ndef write_tensors(model_tensors, L, file, dtype):\n    # writes LLaMA 3 model's weights to a binary file\n    assert dtype in {\"float32\", \"bfloat16\"}\n    write_fun = write_fp32 if dtype == \"float32\" else write_bf16\n    write_fun(model_tensors[\"transformer.wte.weight\"], file) # (V, C)\n    for i in range(L): # (L, C)\n        
write_fun(model_tensors[f\"transformer.h.{i}.ln_1.weight\"], file)\n    for i in range(L): # (L, 3C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.attn.c_attn.weight\"], file)\n    for i in range(L): # (L, C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.attn.c_proj.weight\"], file)\n    for i in range(L): # (L, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.ln_2.weight\"], file)\n    for i in range(L): # (L, 4C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_fc.weight\"], file)\n    for i in range(L): # (L, 4C, C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_fc2.weight\"], file)\n    for i in range(L): # (L, C, 4C)\n        write_fun(model_tensors[f\"transformer.h.{i}.mlp.c_proj.weight\"], file)\n    write_fun(model_tensors[\"transformer.ln_f.weight\"], file) # (C, )\n    write_fun(model_tensors[\"lm_head.weight\"], file) # (V, C)\n\ndef write_model(model, filename, dtype):\n    # everything we need to instantiate the model\n    # 1) header is: version int, LLaMAConfig ints, padding to 1024 bytes\n    assert dtype in {\"float32\", \"bfloat16\"}\n    version = {\n        \"float32\": 3, # 3: all tensors are fp32\n        \"bfloat16\": 5, # 5: all tensors are bf16\n    }[dtype]\n    header = torch.zeros(256, dtype=torch.int32)\n    header[0] = 20240803 # magic\n    header[1] = version # checkpoint version\n    header[2] = model.config.block_size\n    header[3] = model.config.vocab_size\n    header[4] = model.config.n_layer\n    header[5] = model.config.n_head\n    header[6] = model.config.n_kv_head\n    header[7] = model.config.n_embd\n    header[8] = model.config.ffn_dim_multiplier\n    header[9] = model.config.multiple_of\n    header[10] = model.config.norm_eps\n    header[11] = model.config.rope_theta\n    header[12] = model.config.use_scaled_rope\n    header[13] = model.config.max_gen_batch_size\n    header[14] = int(model.config.version.split('.')[0]) # major version\n    header[15] = 
int(model.config.version.split('.')[1]) # minor version\n    # 2) the parameters follow the header\n    params = {name: param.cpu() for name, param in model.named_parameters()}\n    # now write to file\n    with open(filename, \"wb\") as file:\n        file.write(header.numpy().tobytes()) # header\n        write_tensors(params, model.config.n_layer, file, dtype) # params\n    print(f\"wrote {filename}\")\n\ndef write_state(model, x, y, logits, loss, filename):\n    # the state is used for debugging.\n    # it contains information about the input, logits, loss, and the parameter gradients\n    # this can be used for checking the computation correctness in C\n    header = torch.zeros(256, dtype=torch.int32)\n    header[0] = 20240803 # magic\n    header[1] = x.size(0) # batch size of the batch, B\n    header[2] = x.size(1) # temporal extent of the batch, T\n    grads = {name: param.grad.cpu() for name, param in model.named_parameters()}\n    with open(filename, \"wb\") as file:\n        # header\n        file.write(header.numpy().tobytes())\n        # input x\n        file.write(x.cpu().numpy().astype(\"int32\").tobytes()) # (B, T)\n        # targets y\n        file.write(y.cpu().numpy().astype(\"int32\").tobytes()) # (B, T)\n        # logits (result of the model forward pass)\n        write_fp32(logits.cpu(), file)\n        # loss (single float, result of the cross entropy loss)\n        write_fp32(loss.cpu(), file)\n        # gradients\n        write_tensors(grads, model.config.n_layer, file, \"float32\")\n    print(f\"wrote {filename}\")\n\n# -----------------------------------------------------------------------------\n# int main\n\ndef print0(*args, **kwargs):\n    # modified print that only prints from the master process\n    # if this is not a distributed run, it's just a print\n    if int(os.environ.get(\"RANK\", 0)) == 0:\n        print(*args, **kwargs)\n\nif __name__ == \"__main__\":\n    print0(f\"Running pytorch {torch.version.__version__}\")\n\n    # 
default settings will overfit a tiny batch of data\n    # and save model weights and debug state to disk on the first iteration\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--use_hf\", type=int, default=1, help=\"use HuggingFace (default) or use Meta's model\")\n    parser.add_argument(\"--ckpt_dir\", type=str, default=None, help=\"path to llama3 model checkpoint (needed if use_hf=0)\")\n    parser.add_argument(\"--tokenizer_path\", type=str, default=None, help=\"path to llama3 tokenizer (needed if use_hf=0)\")\n    # file system input / output\n    parser.add_argument(\"--input_bin\", type=str, default=\"dev/data/tinyshakespeare/tiny_shakespeare_val.bin\", help=\"input .bin to train on\")\n    parser.add_argument(\"--input_val_bin\", type=str, default=\"\", help=\"input .bin to eval validation loss on\")\n    parser.add_argument(\"--output_dir\", type=str, default=\"\", help=\"output directory to which to write logs and checkpoints\")\n    parser.add_argument(\"--model\", type=str, default=\"meta-llama/Meta-Llama-3.1-8B\", help=\"chose the llama model\")\n    # token layout for each step of the optimization\n    parser.add_argument(\"--batch_size\", type=int, default=4, help=\"batch size, in units of #batch dimensions\")\n    parser.add_argument(\"--sequence_length\", type=int, default=64, help=\"sequence length\")\n    parser.add_argument(\"--total_batch_size\", type=int, default=256, help=\"total desired batch size, in units of #tokens\")\n    # workload (number of steps)\n    parser.add_argument(\"--num_iterations\", type=int, default=10, help=\"number of iterations to run\")\n    parser.add_argument(\"--inference_only\", type=int, default=0, help=\"only run inference\")\n    # optimization\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-5, help=\"learning rate warmup iterations\")\n    parser.add_argument(\"--warmup_iters\", type=int, default=0, help=\"learning rate warmup iterations\")\n    
parser.add_argument(\"--learning_rate_decay_frac\", type=float, default=1.0, help=\"learning rate warmup iterations\")\n    parser.add_argument(\"--weight_decay\", type=float, default=0.0, help=\"weight decay\")\n    parser.add_argument(\"--grad_clip\", type=float, default=1.0, help=\"maximum gradient magnitude\")\n    # evaluation\n    parser.add_argument(\"--val_loss_every\", type=int, default=0, help=\"every how mant steps to evaluate val loss?\")\n    parser.add_argument(\"--val_max_steps\", type=int, default=20, help=\"how many batches of val to average?\")\n    parser.add_argument(\"--sample_every\", type=int, default=0, help=\"how often to sample from the model?\")\n    # debugging\n    parser.add_argument(\"--overfit_single_batch\", type=int, default=1, help=\"overfit just one batch of data\")\n    # numerics\n    parser.add_argument(\"--tensorcores\", type=int, default=0, help=\"use tensorcores\")\n    # memory management\n    parser.add_argument(\"--device\", type=str, default=\"\", help=\"by default we autodetect, or set it here\")\n    parser.add_argument(\"--compile\", type=int, default=0, help=\"torch.compile the model\")\n    parser.add_argument(\"--dtype\", type=str, default=\"bfloat16\", help=\"float32|float16|bfloat16\")\n    parser.add_argument(\"--zero_stage\", type=int, default=0, help=\"zero redundancy optimizer stage (0/1/2/3)\")\n    # python -> C bridge\n    parser.add_argument(\"--write_tensors\", type=int, default=0, help=\"write tensors to disk\")\n    args = parser.parse_args()\n\n    # args error checking and convenience variables\n    B, T = args.batch_size, args.sequence_length\n    assert 1 <= T <= 8192, \"sequence length must be between 1 and 8192\"\n    assert args.dtype in {\"float32\", \"float16\", \"bfloat16\"}\n    assert args.model in {\"meta-llama/Meta-Llama-3.1-8B\"}  # only 8B base model supported for now\n\n    # create the logging directory if it does not exist\n    logfile = None\n    if args.output_dir:\n        
os.makedirs(args.output_dir, exist_ok=True)\n        logfile = os.path.join(args.output_dir, \"main.log\")\n        # create the log file \"main.log\" inside it, and wipe it clean\n        with open(logfile, \"w\") as f:\n            pass\n\n    # set up DDP (distributed data parallel). torchrun sets this env variable\n    ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?\n    if ddp:\n        # use of DDP atm demands CUDA, we set the device appropriately according to rank\n        assert torch.cuda.is_available(), \"for now i think we need CUDA for DDP\"\n        init_process_group(backend='nccl')\n        ddp_rank = int(os.environ['RANK'])\n        ddp_local_rank = int(os.environ['LOCAL_RANK'])\n        ddp_world_size = int(os.environ['WORLD_SIZE'])\n        device = f'cuda:{ddp_local_rank}'\n        torch.cuda.set_device(device)\n        master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.\n        seed_offset = 0 # each process gets the exact same seed\n        zero_stage = args.zero_stage\n    else:\n        ddp_rank = 0\n        ddp_local_rank = 0\n        zero_stage = 0\n        ddp_world_size = 1\n        master_process = True\n        seed_offset = 0\n        # select the device\n        if args.device:\n            # provided explicitly by the user\n            device = args.device\n        else:\n            # attempt to autodetect the device\n            device = \"cpu\"\n            if torch.cuda.is_available():\n                device = \"cuda\"\n            elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n                device = \"mps\"\n    device_type = 'cuda' if 'cuda' in device else 'cpu'\n    assert device_type in {'cuda'}, \"GPU required to run LLaMA 3\"  # we need to load LLaMA as bf16 on CUDA\n    print(f\"using device: {device}\")\n\n    # calculate gradient accumulation from the desired total batch size and the current run configuration\n    tokens_per_fwdbwd = B * T 
* ddp_world_size\n    assert args.total_batch_size % tokens_per_fwdbwd == 0\n    grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd\n    print0(f\"total desired batch size: {args.total_batch_size}\")\n    print0(f\"=> calculated gradient accumulation steps: {grad_accum_steps}\")\n\n    # set up a context manager following the desired dtype and device\n    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]\n    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if (device_type == \"cuda\") else nullcontext()\n\n    # rng / reproducibility\n    torch.manual_seed(42)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(42)\n\n    # set the torch precision mode to use TensorFloat32 (TF32) for matmuls\n    # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html\n    if args.tensorcores:\n        torch.set_float32_matmul_precision('high')\n\n    # init the model\n    if args.use_hf:\n        model = LLaMA.from_pretrained_llama3_hf(args.model)\n    else:  # use Meta's checkpoint\n        assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f\"llama3 ckpt dir {args.ckpt_dir} does not exist\"\n        assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f\"llama3 tokenizer path {args.tokenizer_path} does not exist\"\n        model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)\n\n    model.train()\n    if args.compile:\n        if hasattr(config, \"coordinate_descent_tuning\"):\n            config.coordinate_descent_tuning = True # suggested by @Chillee\n        print0(\"compiling the model...\")\n        model = torch.compile(model)\n\n    # -------------------------------------------------------------------------\n    # Our own version of a simple DistributedDataLoader\n\n    # load tokens\n    train_loader = DistributedShardedDataLoader(args.input_bin, B, T, ddp_rank, 
ddp_world_size)\n    val_loader = None\n    if args.input_val_bin:\n        val_loader = DistributedShardedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)\n\n    # -------------------------------------------------------------------------\n    # PyTorch -> C bridge: save some weights and state for C to load later as reference\n\n    # do one forward pass to generate ground truth for our C tests\n    if master_process and args.write_tensors and (not args.inference_only):\n        x, y = train_loader.next_batch()\n        x, y = x.to(device), y.to(device)\n        logits, loss = model(x, y)\n        loss.backward()\n        # save model params, in bfloat16\n        model_to_size = {\"meta-llama/Meta-Llama-3.1-8B\": \"8B\"}\n        model_size_str = model_to_size[args.model] # e.g. \"8B\"\n        write_model(model, os.path.join(args.output_dir, f\"llama3.1_{model_size_str}_bf16.bin\"), dtype=\"bfloat16\")\n        # save x, y, logits, loss, and parameter gradients, for debugging C\n        # always store these in fp32 to have an accurate reference (?)\n        write_state(model, x, y, logits, loss, os.path.join(args.output_dir, f\"llama3_{model_size_str}_debug_state.bin\"))\n        # reset the train_loader for the optimization below\n        train_loader.reset()\n\n    # -------------------------------------------------------------------------\n    # main training loop\n\n    # here we wrap model into DDP container\n    if ddp:\n        model = DDP(model, device_ids=[ddp_local_rank])\n    raw_model = model.module if ddp else model # always contains the \"raw\" unwrapped model\n\n    # init the optimizer\n    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,\n                                               learning_rate=args.learning_rate, betas=(0.9, 0.95),\n                                               device_type=device, zero_stage=zero_stage)\n\n    # learning rate decay scheduler (cosine with warmup)\n    def get_lr(it):\n  
      min_lr = args.learning_rate * args.learning_rate_decay_frac\n        # 1) linear warmup for warmup_iters steps\n        if it < args.warmup_iters:\n            return args.learning_rate * (it+1) / args.warmup_iters\n        # 2) if it > lr_decay_iters, return min learning rate\n        if it > args.num_iterations:\n            return min_lr\n        # 3) in between, use cosine decay down to min learning rate\n        decay_ratio = (it - args.warmup_iters) / (args.num_iterations - args.warmup_iters)\n        assert 0 <= decay_ratio <= 1\n        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0\n        return min_lr + coeff * (args.learning_rate - min_lr)\n\n    if device == \"cuda\":\n        torch.cuda.reset_peak_memory_stats()\n    timings = []\n    norm = -1.0   # dummy value to print in inference-only mode\n    for step in range(args.num_iterations + 1):\n        t0 = time.time()\n        last_step = (step == args.num_iterations)\n\n        # once in a while evaluate the validation dataset\n        if (args.val_loss_every > 0 \\\n            and (step % args.val_loss_every == 0 or last_step)) \\\n            and (val_loader is not None):\n            model.eval()\n            val_loader.reset()\n            with torch.no_grad():\n                val_loss = 0.0\n                for _ in range(args.val_max_steps):\n                    x, y = val_loader.next_batch()\n                    x, y = x.to(device), y.to(device)\n                    _, loss = model(x, y, return_logits=False)\n                    val_loss += loss.item()\n                val_loss /= args.val_max_steps\n            # log to console and to file\n            print0(f\"val loss {val_loss}\")\n            if master_process and logfile is not None:\n                with open(logfile, \"a\") as f:\n                    f.write(\"s:%d tel:%f\\n\" % (step, val_loss))\n\n        # once in a while perform model inference on the master process\n        if 
(args.sample_every > 0 \\\n            and (step % args.sample_every == 0 or last_step)) \\\n            and master_process:\n            model.eval()\n            prompts: List[str] = [\n        \"Clearly, the meaning of life is\",\n        \"Simply put, the theory of relativity states that\",\n        \"\"\"The repo llm.c on GitHub is\"\"\",\n        \"\"\"Translate English to French:\n\n        sea otter => loutre de mer\n        peppermint => menthe poivrée\n        plush girafe => girafe peluche\n        cheese =>\"\"\",\n            ]\n            if args.use_hf:\n                prompt_tokens = [model.tokenizer(x).input_ids for x in prompts]\n            else:  # Meta\n                prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]\n\n            generation_tokens = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, echo=False)\n            results = [{\"generation\": model.tokenizer.decode(t)} for t in generation_tokens]\n            for prompt, result in zip(prompts, results):\n                print(prompt, end=\"\")\n                print(f\"{result['generation']}\")\n                print(\"\\n==================================\\n\")\n\n        # bit confusing: we want to make sure to eval and sample on 0th iteration\n        # but also after the very last iteration. 
so we loop for step <= num_iterations\n        # instead of just < num_iterations (one extra due to <=), only to do\n        # the validation/sampling one last time, and then we break right here as we're done.\n        if last_step:\n            break\n\n        # --------------- TRAINING SECTION BEGIN -----------------\n        model.train()\n        optimizer.zero_grad(set_to_none=True)\n        # if we are trying to overfit a single batch, we reset the loader here\n        if args.overfit_single_batch:\n            train_loader.reset()\n        # micro-batch loop where we do gradient accumulation to reach desired total batch size\n        lossf = 0.0 # for getting the mean loss (as simple float) over the accumulation steps\n        for micro_step in range(grad_accum_steps):\n            # fetch a batch\n            x, y = train_loader.next_batch()\n            x, y = x.to(device), y.to(device)\n            if ddp:\n                # we want only the last micro-step to sync grads in a DDP model\n                # the official way to do this is with model.no_sync(), but that is a\n                # context manager that bloats the code, so we just toggle this variable\n                model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)\n            # forward pass\n            with ctx:\n                _, loss = model(x, y, return_logits=False)\n                # we have to scale the loss to account for gradient accumulation,\n                # because the gradients just add on each successive backward().\n                # addition of gradients corresponds to a SUM in the objective, but\n                # instead of a SUM we want MEAN, so we scale the loss here\n                loss = loss / grad_accum_steps\n                lossf += loss.detach() # keep track of the mean loss\n            # backward pass\n            if not args.inference_only:\n                loss.backward()\n        if ddp:\n            dist.all_reduce(lossf, 
op=dist.ReduceOp.AVG)\n        lossf = lossf.item()\n        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)\n        # determine and set the learning rate for this iteration\n        lr = get_lr(step)\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr\n        # step the optimizer\n        optimizer.step()\n        # --------------- TRAINING SECTION END -------------------\n        # everything that follows now is just diagnostics, prints, logging, etc.\n\n        # wait on the CPU for all device work to end so we get accurate per-iteration timings below\n        if device == \"mps\":\n            torch.mps.synchronize()\n        elif device == \"cuda\":\n            torch.cuda.synchronize()\n        # time and print\n        t1 = time.time()\n        # the 0th iteration is often an outlier (much slower) => skip logging it\n        tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)\n        print0(f\"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)\")\n        # log to logile\n        if master_process and logfile is not None:\n            with open(logfile, \"a\") as f:\n                f.write(\"s:%d trl:%f\\n\" % (step, lossf))\n\n        # keep track of smooth timings, last 20 iterations\n        if step > 0 and step > args.num_iterations - 20:\n            timings.append(t1-t0)\n\n    # print the average of the last 20 timings, to get something smooth-ish\n    timings = timings[-20:]\n    print0(f\"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms\")\n    print0(f\"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB\")\n\n    # -------------------------------------------------------------------------\n    # clean up nice\n    if ddp:\n        destroy_process_group()\n"
  }
]