[
  {
    "path": ".gitignore",
    "content": "*.pyc\nadd_license.py"
  },
  {
    "path": "README.md",
    "content": "# AMD-135M\nThis repository provides the implementation for training AMD-135M models and is based on [TinyLlama](https://github.com/jzhang38/TinyLlama).\n\nAMD-135M is a language model trained on AMD MI250 GPUs. Based on LLaMA2 model architecture, this model can be smoothly loaded as LlamaForCausalLM with huggingface transformers. Furthermore, we use the same tokenizer as LLaMA2, enableing it to be a draft model of speculative decoding for LLaMA2 and CodeLlama.\n\n### Docker image\nPlease use the following rocm docker in [docker hub](https://hub.docker.com/layers/rocm/pytorch/rocm6.1_ubuntu20.04_py3.9_pytorch_2.3.0_preview/images/sha256-0136f3e678290e0ae78cdd78c90d9f849ee3ac3602864c486e0252f8f8b9662b?context=explore) \n\n`docker pull rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.3.0_preview`\n\n### Python packages dependency\nPlease run `pip install -r requirement.txt` to install extra python packages based on the docker above.\n\n### Dataset\nStep 1, download [SlimPajama-627](https://huggingface.co/datasets/cerebras/SlimPajama-627B), [project gutenberg](https://huggingface.co/datasets/manu/project_gutenberg) and [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata).\n\n```bash\ngit clone https://huggingface.co/datasets/cerebras/SlimPajama-627B\ngit clone https://huggingface.co/datasets/manu/project_gutenberg\ngit clone https://huggingface.co/datasets/bigcode/starcoderdata\n```\n\nStep 2, process the text data into token ids. And you will find the processed dataset at `./slim_processed`, `./slim_validation_processed` and `./starcoderdata_python_processed`.\n\n```bash\n# For pretraining\nbash ./scripts/prepare_slimpajama_train.sh\nbash ./scripts/prepare_project_gutenberg.sh\n# For validation\nbash ./scripts/prepare_slimpajama_valid.sh\n# For code finetuning\nbash ./scripts/prepare_starcoder_python.sh\n```\n\n### Pretraining\nTo train a tinyllama model, please run the following scripts on 4 nodes, 4 MI250 GPUs (8 vitural devices) for each node.\n\n```bash\n# run on node 0.\nbash ./cluster/pretrain_node_0.sh\n# run on node 1.\nbash ./cluster/pretrain_node_1.sh\n# run on node 2.\nbash ./cluster/pretrain_node_2.sh\n# run on node 3.\nbash ./cluster/pretrain_node_3.sh\n```\n\n### Code Finetuning\nTo finetune a tinyllama model, please run the following script.\n\n```bash\nbash ./cluster/finetune.sh\n```\n\n### Evaluation\nWe evaluate AMD-Llama-135m using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) on popular NLP benchmarks and results are listed as follows.\n\n| **Model**            | **SciQ**      | **WinoGrande** | **PIQA**      | **WSC**       | **MMLU**      | **Lambada (OpenAI)** | **ARC - Easy** | **ARC - Challenge** | **LogiQA**    | **Hellaswag** |\n|----------------------|---------------|----------------|---------------|---------------|---------------|----------------------|----------------|---------------------|---------------|---------------|\n| GPT2-124M (small)    | 0.753±0.0136  | 0.5162±0.0140  | 0.6289±0.0113 | 0.4327±0.0488 | 0.2292±0.0383 | 0.3256±0.0065        | 0.4381±0.0102  | 0.1903±0.0115       | 0.2181±0.0162 | 0.2892±0.0045 |\n| OPT-125M             | 0.751±0.014   | 0.503±0.014    | 0.630±0.011   | 0.365±0.047   | 0.229±0.038   | 0.379±0.007          | 0.436±0.010    | 0.191±0.012         | 0.229±0.016   | 0.292±0.004   |\n| JackFram/llama-68m   | 0.652±0.0151  | 0.513±0.014    | 0.6197±0.0113 | 0.4038±0.0483 | 0.2302±0.0035 | 0.1351±0.0048        | 0.3864±0.0100  | 0.1792±0.0112       | 0.2273±0.0164 | 
\n### Docker image\nPlease use the following ROCm Docker image from [Docker Hub](https://hub.docker.com/layers/rocm/pytorch/rocm6.1_ubuntu20.04_py3.9_pytorch_2.3.0_preview/images/sha256-0136f3e678290e0ae78cdd78c90d9f849ee3ac3602864c486e0252f8f8b9662b?context=explore):\n\n`docker pull rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.3.0_preview`\n\n### Python package dependencies\nPlease run `pip install -r requirement.txt` to install the extra Python packages needed on top of the Docker image above.\n\n### Dataset\nStep 1, download [SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B), [Project Gutenberg](https://huggingface.co/datasets/manu/project_gutenberg) and [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata).\n\n```bash\ngit clone https://huggingface.co/datasets/cerebras/SlimPajama-627B\ngit clone https://huggingface.co/datasets/manu/project_gutenberg\ngit clone https://huggingface.co/datasets/bigcode/starcoderdata\n```\n\nStep 2, process the text data into token IDs. The processed datasets will be written to `./slim_processed`, `./slim_validation_processed` and `./starcoderdata_python_processed`.\n\n```bash\n# For pretraining\nbash ./scripts/prepare_slimpajama_train.sh\nbash ./scripts/prepare_project_gutenberg.sh\n# For validation\nbash ./scripts/prepare_slimpajama_valid.sh\n# For code finetuning\nbash ./scripts/prepare_starcoder_python.sh\n```\n\n### Pretraining\nTo pretrain the model, run the following scripts on 4 nodes, each with 4 MI250 GPUs (8 virtual devices per node).\n\n```bash\n# run on node 0.\nbash ./cluster/pretrain_node_0.sh\n# run on node 1.\nbash ./cluster/pretrain_node_1.sh\n# run on node 2.\nbash ./cluster/pretrain_node_2.sh\n# run on node 3.\nbash ./cluster/pretrain_node_3.sh\n```\n\n### Code Finetuning\nTo finetune the model on code, run the following script.\n\n```bash\nbash ./cluster/finetune.sh\n```\n\n### Evaluation\nWe evaluate AMD-Llama-135m using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) on popular NLP benchmarks; results are listed as follows.\n\n| **Model**            | **SciQ**      | **WinoGrande** | **PIQA**      | **WSC**       | **MMLU**      | **Lambada (OpenAI)** | **ARC - Easy** | **ARC - Challenge** | **LogiQA**    | **Hellaswag** |\n|----------------------|---------------|----------------|---------------|---------------|---------------|----------------------|----------------|---------------------|---------------|---------------|\n| GPT2-124M (small)    | 0.753±0.0136  | 0.5162±0.0140  | 0.6289±0.0113 | 0.4327±0.0488 | 0.2292±0.0383 | 0.3256±0.0065        | 0.4381±0.0102  | 0.1903±0.0115       | 0.2181±0.0162 | 0.2892±0.0045 |\n| OPT-125M             | 0.751±0.014   | 0.503±0.014    | 0.630±0.011   | 0.365±0.047   | 0.229±0.038   | 0.379±0.007          | 0.436±0.010    | 0.191±0.012         | 0.229±0.016   | 0.292±0.004   |\n| JackFram/llama-68m   | 0.652±0.0151  | 0.513±0.014    | 0.6197±0.0113 | 0.4038±0.0483 | 0.2302±0.0035 | 0.1351±0.0048        | 0.3864±0.0100  | 0.1792±0.0112       | 0.2273±0.0164 | 0.2790±0.0045 |\n| JackFram/llama-160m  | 0.724±0.0141  | 0.5012±0.0141  | 0.6605±0.011  | 0.3654±0.0474 | 0.2299±0.0035 | 0.3134±0.0065        | 0.4335±0.0102  | 0.1980±0.0116       | 0.2197±0.0162 | 0.3094±0.0046 |\n| [AMD-Llama-135m](https://huggingface.co/amd/AMD-Llama-135m)       | 0.761±0.0135  | 0.5012±0.0141  | 0.6420±0.0112 | 0.3654±0.0474 | 0.2302±0.0035 | 0.3330±0.0066        | 0.4364±0.0102  | 0.1911±0.0115       | 0.2120±0.0160 | 0.3048±0.0046 |\n\n\n### Speculative Decoding\nTo run speculative decoding using AMD-Llama-135m-code as the draft model for CodeLlama-7b on the [HumanEval](https://huggingface.co/datasets/openai_humaneval) dataset, run the following script.\n\n```bash\n# Patch huggingface transformers==4.37.2 to add the logging needed to calculate the acceptance rate of speculative decoding.\npatch -u /path/to/transformers/generation/utils.py -i ./speculative_decoding/utils.patch\nbash ./speculative_decoding/codellama_spec.sh\n```\n
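\nAlternatively, speculative decoding can be sketched directly with the `assistant_model` argument of `generate` in Hugging Face transformers. This is a hypothetical snippet, not the benchmark script: the checkpoint names are assumptions, and the patch above is only needed when measuring acceptance rates.\n\n```python\nimport torch\nfrom transformers import AutoTokenizer, LlamaForCausalLM\n\n# assumed checkpoint names; target and draft share the LLaMA2 tokenizer, which assisted generation requires\ntarget = LlamaForCausalLM.from_pretrained(\"codellama/CodeLlama-7b-hf\", torch_dtype=torch.bfloat16, device_map=\"auto\")\ndraft = LlamaForCausalLM.from_pretrained(\"amd/AMD-Llama-135m-code\", torch_dtype=torch.bfloat16, device_map=\"auto\")\ntokenizer = AutoTokenizer.from_pretrained(\"codellama/CodeLlama-7b-hf\")\n\ninputs = tokenizer(\"def quicksort(arr):\", return_tensors=\"pt\").to(target.device)\n# the draft model proposes tokens that the target model verifies in parallel\noutputs = target.generate(**inputs, assistant_model=draft, max_new_tokens=128, do_sample=False)\nprint(tokenizer.decode(outputs[0], skip_special_tokens=True))\n```\n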
\nWe evaluate decoding performance with the target model alone and with speculative decoding on an MI250 GPU and a Ryzen AI CPU (with NPU kernel). All experiments are run on the HumanEval dataset.\n\n| Target Model Device   | Draft Model Device   | Random Sampling   | Target Model HumanEval Pass@1 | Speculative Decoding HumanEval Pass@1 | Acceptance Rate | Throughput Speedup |\n|:----------------------|:---------------------|:------------------|-------------------------------:|---------------------------------------:|----------------:|-------------------:|\n| FP32 MI250            | FP32 MI250           | TRUE              | 32.31%                        | 29.27%                                | 0.650355        | 2.58x              |\n| FP32 MI250            | FP32 MI250           | FALSE             | 31.10%                        | 31.10%                                | 0.657839        | **2.80x**          |\n| BF16 MI250            | BF16 MI250           | TRUE              | 31.10%                        | 31.10%                                | 0.668822        | 1.67x              |\n| BF16 MI250            | BF16 MI250           | FALSE             | 34.15%                        | 33.54%                                | 0.665497        | 1.75x              |\n| INT4 NPU              | BF16 CPU             | TRUE              | 28.05%                        | 30.49%                                | 0.722913        | 2.83x              |\n| INT4 NPU              | BF16 CPU             | FALSE             | 28.66%                        | 28.66%                                | 0.738072        | **2.98x**          |\n| BF16 CPU              | BF16 CPU             | TRUE              | 31.10%                        | 31.71%                                | 0.723971        | 3.68x              |\n| BF16 CPU              | BF16 CPU             | FALSE             | 33.54%                        | 33.54%                                | 0.727548        | **3.88x**          |\n| FP32 CPU              | FP32 CPU             | TRUE              | 29.87%                        | 28.05%                                | 0.727214        | 3.57x              |\n| FP32 CPU              | FP32 CPU             | FALSE             | 31.10%                        | 31.10%                                | 0.738641        | 3.66x              |\n\n\n### Training and finetuning cost\nPretraining AMD-Llama-135m takes 6 days on 4 MI250 nodes, each with 4 MI250 GPUs (8 virtual GPU cards, 64 GB memory each).\n\nFinetuning AMD-Llama-135m-code takes 4 days on 4 MI250 GPUs.\n\nStoring the raw and processed SlimPajama, Project Gutenberg and StarCoder datasets takes 11 TB of disk space.\n\n### ROCm\n```\nVersion: 6.1.2.60102-119~20.04\nPriority: optional\nSection: devel\nMaintainer: ROCm Dev Support <rocm-dev.support@amd.com>\nInstalled-Size: 13.3 kB\nDepends: hipblas (= 2.1.0.60102-119~20.04), hipblaslt (= 0.7.0.60102-119~20.04), hipfft (= 1.0.14.60102-119~20.04), hipsolver (= 2.1.1.60102-119~20.04), hipsparse (= 3.0.1.60102-119~20.04), hiptensor (= 1.2.0.60102-119~20.04), miopen-hip (= 3.1.0.60102-119~20.04), half (= 1.12.0.60102-119~20.04), rccl (= 2.18.6.60102-119~20.04), rocalution (= 3.1.1.60102-119~20.04), rocblas (= 4.1.2.60102-119~20.04), rocfft (= 1.0.27.60102-119~20.04), rocrand (= 3.0.1.60102-119~20.04), hiprand (= 2.10.16.60102-119~20.04), rocsolver (= 3.25.0.60102-119~20.04), rocsparse (= 3.1.2.60102-119~20.04), rocm-core (= 6.1.2.60102-119~20.04), hipsparselt (= 0.2.0.60102-119~20.04), composablekernel-dev (= 1.1.0.60102-119~20.04), hipblas-dev (= 2.1.0.60102-119~20.04), hipblaslt-dev (= 0.7.0.60102-119~20.04), hipcub-dev (= 3.1.0.60102-119~20.04), hipfft-dev (= 1.0.14.60102-119~20.04), hipsolver-dev (= 2.1.1.60102-119~20.04), hipsparse-dev (= 3.0.1.60102-119~20.04), hiptensor-dev (= 1.2.0.60102-119~20.04), miopen-hip-dev (= 3.1.0.60102-119~20.04), rccl-dev (= 2.18.6.60102-119~20.04), rocalution-dev (= 3.1.1.60102-119~20.04), rocblas-dev (= 4.1.2.60102-119~20.04), rocfft-dev (= 1.0.27.60102-119~20.04), rocprim-dev (= 3.1.0.60102-119~20.04), rocrand-dev (= 3.0.1.60102-119~20.04), hiprand-dev (= 2.10.16.60102-119~20.04), rocsolver-dev (= 3.25.0.60102-119~20.04), rocsparse-dev (= 3.1.2.60102-119~20.04), rocthrust-dev (= 3.0.1.60102-119~20.04), rocwmma-dev (= 1.4.0.60102-119~20.04), hipsparselt-dev (= 0.2.0.60102-119~20.04)\nHomepage: https://github.com/RadeonOpenCompute/ROCm\nDownload-Size: 1064 B\nAPT-Manual-Installed: yes\nAPT-Sources: http://repo.radeon.com/rocm/apt/6.1.2 focal/main amd64 Packages\nDescription: Radeon Open Compute (ROCm) Runtime software stack\n```\n### System info\n```\nUbuntu 22.04.3 LTS\nRelease:        22.04\nCodename:       jammy\n\nLinux version 5.15.0-88-generic (buildd@lcy02-amd64-058) (gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0, GNU ld (GNU Binutils for Ubuntu) 2.38) #98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023\n\nLinux sjc144-canary-node035.dcgpu.amd.com 5.15.0-88-generic #98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux\n```\n### License\nCopyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n"
  },
  {
    "path": "cluster/finetune.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport TRAIN_DATA_PATH=/path/to/preprocessed/training/data\nexport VALID_DATA_PATH=/path/to/preprocessed/validation/data\nexport CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'\nexport BASE_MODEL_PATH=/path/to/base/model/ckpt\nexport NCCL_SOCKET_IFNAME={network interface name}\nexport MASTER_ADDRESS={master node ip}\nexport MAIN_OPRT={port}\n\nMODEL_NAME='tiny_LLaMA_135M_2k'\nlightning run model \\\n    --node-rank=0  \\\n    --main-address=$MASTER_ADDRESS \\\n    --accelerator=cuda \\\n    --devices=8 \\\n    --num-nodes=1 \\\n    --main-port=$MAIN_OPRT \\\n    pretrain/tinyllama_code.py --devices 8 --train_data_dir $TRAIN_DATA_PATH  --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME \\\n    --checkpoint_path $BASE_MODEL_PATH\n"
  },
  {
    "path": "cluster/pretrain.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport TRAIN_DATA_PATH=/path/to/preprocessed/training/data\nexport VALID_DATA_PATH=/path/to/preprocessed/validation/data\nexport CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'\nexport NCCL_SOCKET_IFNAME={network interface name}\nexport MASTER_ADDRESS={master node ip}\nexport MAIN_OPRT={port}\n\nMODEL_NAME='tiny_LLaMA_135M_2k'\nlightning run model \\\n    --node-rank=0  \\\n    --main-address=$MASTER_ADDRESS \\\n    --accelerator=cuda \\\n    --devices=8 \\\n    --num-nodes=1 \\\n    --main-port=$MAIN_OPRT \\\n    pretrain/tinyllama.py --devices 8 --train_data_dir $TRAIN_DATA_PATH  --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\\\n"
  },
  {
    "path": "cluster/pretrain_node_0.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport TRAIN_DATA_PATH=/path/to/preprocessed/training/data\nexport VALID_DATA_PATH=/path/to/preprocessed/validation/data\nexport CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'\nexport NCCL_SOCKET_IFNAME={network interface name}\nexport MASTER_ADDRESS={master node ip}\nexport MAIN_OPRT={port}\n\nMODEL_NAME='tiny_LLaMA_135M_2k'\nlightning run model \\\n    --node-rank=0  \\\n    --main-address=$MASTER_ADDRESS \\\n    --accelerator=cuda \\\n    --devices=8 \\\n    --num-nodes=4 \\\n    --main-port=$MAIN_OPRT \\\n    pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH  --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\\\n"
  },
  {
    "path": "cluster/pretrain_node_1.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport TRAIN_DATA_PATH=/path/to/preprocessed/training/data\nexport VALID_DATA_PATH=/path/to/preprocessed/validation/data\nexport CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'\nexport NCCL_SOCKET_IFNAME={network interface name}\nexport MASTER_ADDRESS={master node ip}\nexport MAIN_OPRT={port}\n\nMODEL_NAME='tiny_LLaMA_135M_2k'\nlightning run model \\\n    --node-rank=1  \\\n    --main-address=$MASTER_ADDRESS \\\n    --accelerator=cuda \\\n    --devices=8 \\\n    --num-nodes=4 \\\n    --main-port=$MAIN_OPRT \\\n    pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH  --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\\\n"
  },
  {
    "path": "cluster/pretrain_node_2.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport TRAIN_DATA_PATH=/path/to/preprocessed/training/data\nexport VALID_DATA_PATH=/path/to/preprocessed/validation/data\nexport CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'\nexport NCCL_SOCKET_IFNAME={network interface name}\nexport MASTER_ADDRESS={master node ip}\nexport MAIN_OPRT={port}\n\nMODEL_NAME='tiny_LLaMA_135M_2k'\nlightning run model \\\n    --node-rank=2  \\\n    --main-address=$MASTER_ADDRESS \\\n    --accelerator=cuda \\\n    --devices=8 \\\n    --num-nodes=4 \\\n    --main-port=$MAIN_OPRT \\\n    pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH  --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\\\n"
  },
  {
    "path": "cluster/pretrain_node_3.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport TRAIN_DATA_PATH=/path/to/preprocessed/training/data\nexport VALID_DATA_PATH=/path/to/preprocessed/validation/data\nexport CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'\nexport NCCL_SOCKET_IFNAME={network interface name}\nexport MASTER_ADDRESS={master node ip}\nexport MAIN_OPRT={port}\n\nMODEL_NAME='tiny_LLaMA_135M_2k'\nlightning run model \\\n    --node-rank=3  \\\n    --main-address=$MASTER_ADDRESS \\\n    --accelerator=cuda \\\n    --devices=8 \\\n    --num-nodes=4 \\\n    --main-port=$MAIN_OPRT \\\n    pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH  --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\\\n"
  },
  {
    "path": "lit_gpt/__init__.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom lit_gpt.model import GPT\nfrom lit_gpt.config import Config\nfrom lit_gpt.tokenizer import Tokenizer\nfrom lightning_utilities.core.imports import RequirementCache\n\n__all__ = [\"GPT\", \"Config\", \"Tokenizer\"]\n"
  },
  {
    "path": "lit_gpt/adapter.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n\"\"\"Implementation of the paper:\n\nLLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention\nhttps://arxiv.org/abs/2303.16199\n\nPort for Lit-GPT\n\"\"\"\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom typing_extensions import Self\n\nfrom lit_gpt.config import Config as BaseConfig\nfrom lit_gpt.model import GPT as BaseModel\nfrom lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention\nfrom lit_gpt.model import KVCache, RoPECache, apply_rope\n\n\n@dataclass\nclass Config(BaseConfig):\n    adapter_prompt_length: int = 10\n    adapter_start_layer: int = 2\n\n\nclass GPT(BaseModel):\n    \"\"\"The implementation is identical to `lit_gpt.model.GPT` with the exception that\n    the `Block` saves the layer index and passes it down to the attention layer.\"\"\"\n\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n\n        self.rope_cache: Optional[RoPECache] = None\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.kv_caches: List[KVCache] = []\n        self.adapter_kv_caches: List[KVCache] = []\n\n    def reset_cache(self) -> None:\n        super().reset_cache()\n        self.adapter_kv_caches.clear()\n\n    def forward(\n        self,\n        idx: torch.Tensor,\n        max_seq_length: Optional[int] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        lm_head_chunk_size: int = 0,\n    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n        B, T = idx.size()\n        use_kv_cache = input_pos is not None\n\n        block_size = self.config.block_size\n        if max_seq_length is None:\n            max_seq_length = block_size\n        if use_kv_cache:  # not relevant otherwise\n            assert (\n                max_seq_length >= T\n            ), f\"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}\"\n        assert max_seq_length <= block_size, f\"Cannot attend to {max_seq_length}, block size is only {block_size}\"\n        assert block_size >= T, f\"Cannot forward sequence of length {T}, block size is only {block_size}\"\n\n        if self.rope_cache is None:\n            self.rope_cache = self.build_rope_cache(idx)\n        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
since we only need the mask\n        # for the kv-cache support (only during inference), we only create it in that situation\n        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099\n        if use_kv_cache and self.mask_cache is None:\n            self.mask_cache = self.build_mask_cache(idx)\n\n        cos, sin = self.rope_cache\n        if use_kv_cache:\n            cos = cos.index_select(0, input_pos)\n            sin = sin.index_select(0, input_pos)\n            mask = self.mask_cache.index_select(2, input_pos)\n            mask = mask[:, :, :, :max_seq_length]\n        else:\n            cos = cos[:T]\n            sin = sin[:T]\n            mask = None\n\n        # forward the model itself\n        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)\n\n        if not use_kv_cache:\n            for block in self.transformer.h:\n                x, *_ = block(x, (cos, sin), max_seq_length)\n        else:\n            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1))\n            self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)]\n            for i, block in enumerate(self.transformer.h):\n                x, self.kv_caches[i], self.adapter_kv_caches[i] = block(\n                    x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]\n                )\n\n        x = self.transformer.ln_f(x)\n\n        if lm_head_chunk_size > 0:\n            # chunk the lm head logits to reduce the peak memory used by autograd\n            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]\n        return self.lm_head(x)  # (b, t, vocab_size)\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.\"\"\"\n        super()._init_weights(module)\n        if isinstance(module, CausalSelfAttention):\n            module.reset_parameters()\n\n\nclass Block(nn.Module):\n    \"\"\"The implementation is identical to `lit_gpt.model.Block` with the exception that\n    we replace the attention layer where adaption is implemented.\"\"\"\n\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__()\n        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.attn = CausalSelfAttention(config, block_idx)\n        if not config.shared_attention_norm:\n            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.mlp = config.mlp_class(config)\n\n        self.config = config\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        rope: RoPECache,\n        max_seq_length: int,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        kv_cache: Optional[KVCache] = None,\n        adapter_kv_cache: Optional[KVCache] = None,\n    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:\n        n_1 = self.norm_1(x)\n        h, new_kv_cache, new_adapter_kv_cache = self.attn(\n            n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache\n        )\n        if self.config.parallel_residual:\n            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)\n            x = x + h + self.mlp(n_2)\n        else:\n            if self.config.shared_attention_norm:\n                raise NotImplementedError(\n                    \"No checkpoint amongst the ones we support uses this configuration\"\n                    \" (non-parallel residual and shared attention norm).\"\n                )\n            x = x + h\n            x = x + self.mlp(self.norm_2(x))\n        return x, new_kv_cache, new_adapter_kv_cache\n\n\nclass CausalSelfAttention(BaseCausalSelfAttention):\n    \"\"\"A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention\n    over the adaption prompt.\"\"\"\n\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config)\n        if block_idx >= config.adapter_start_layer:\n            # adapter embedding layer\n            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)\n            # gate for adaption\n            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))\n            self.reset_parameters()\n        self.block_idx = block_idx\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        rope: RoPECache,\n        max_seq_length: int,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        kv_cache: Optional[KVCache] = None,\n        adapter_kv_cache: Optional[KVCache] = None,\n    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:\n        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n\n        qkv = self.attn(x)\n\n        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)\n        q_per_kv = self.config.n_head // self.config.n_query_groups\n        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value\n        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)\n        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, 
total_qkv, T, hs)\n\n        # split batched computation into three\n        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)\n\n        # repeat k and v if necessary\n        if self.config.n_query_groups != 1:  # doing this would require a full kv cache with MQA (inefficient!)\n            # for MHA this is a no-op\n            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)\n            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)\n\n        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)\n        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)\n        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)\n\n        n_elem = int(self.config.rotary_percentage * self.config.head_size)\n\n        cos, sin = rope\n        q_roped = apply_rope(q[..., :n_elem], cos, sin)\n        k_roped = apply_rope(k[..., :n_elem], cos, sin)\n        q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)\n        k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)\n\n        if kv_cache is not None:\n            cache_k, cache_v = kv_cache\n            cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)\n            # check if reached token limit\n            if input_pos[-1] >= max_seq_length:\n                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)\n                # shift 1 position to the left\n                cache_k = torch.roll(cache_k, -1, dims=2)\n                cache_v = torch.roll(cache_v, -1, dims=2)\n            k = cache_k.index_copy_(2, input_pos, k)\n            v = cache_v.index_copy_(2, input_pos, v)\n            kv_cache = k, v\n\n        y = self.scaled_dot_product_attention(q, k, v, mask=mask)\n\n        if self.block_idx >= self.config.adapter_start_layer:\n            aT = self.config.adapter_prompt_length\n            if adapter_kv_cache is not None:\n                ak, av = adapter_kv_cache\n            else:\n                prefix = self.adapter_wte.weight.reshape(1, aT, C)\n                aqkv = self.attn(prefix)\n                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)\n                aqkv = aqkv.permute(0, 2, 3, 1, 4)\n                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)\n                if self.config.n_query_groups != 1:\n                    # for MHA this is a no-op\n                    ak = ak.repeat_interleave(q_per_kv, dim=2)\n                    av = av.repeat_interleave(q_per_kv, dim=2)\n                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)\n                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)\n                adapter_kv_cache = (ak, av)\n\n            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)\n            ay = self.scaled_dot_product_attention(q, ak, av, amask)\n            y = y + self.gating_factor * ay\n\n        y = y.reshape(B, T, C)  # re-assemble all head outputs side by side\n\n        # output projection\n        y = self.proj(y)\n\n        return y, kv_cache, adapter_kv_cache\n\n    def reset_parameters(self) -> None:\n        torch.nn.init.zeros_(self.gating_factor)\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with older checkpoints.\"\"\"\n        if (key := prefix + \"gating_factor\") in state_dict and state_dict[key].size(1) == self.config.n_head:\n            
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef mark_only_adapter_as_trainable(model: GPT) -> None:\n    \"\"\"Sets `requires_grad=False` for all non-adapter weights.\"\"\"\n    for name, param in model.named_parameters():\n        param.requires_grad = adapter_filter(name, param)\n\n\ndef adapter_filter(key: str, value: Any) -> bool:\n    return \"adapter_wte\" in key or \"gating_factor\" in key\n"
  },
  {
    "path": "lit_gpt/adapter_v2.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n\"\"\"Implementation of the paper:\n\nLLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model\nhttps://arxiv.org/abs/2304.15010\n\nPort for Lit-GPT\n\"\"\"\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Type\n\nimport torch\nimport torch.nn as nn\nfrom typing_extensions import Self\n\nimport lit_gpt\nfrom lit_gpt.adapter import GPT as BaseModel\nfrom lit_gpt.adapter import Block as BaseBlock\nfrom lit_gpt.adapter import Config as BaseConfig\nfrom lit_gpt.adapter import KVCache, RoPECache\nfrom lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention\nfrom lit_gpt.model import apply_rope\nfrom lit_gpt.utils import map_old_state_dict_weights\n\n\n@dataclass\nclass Config(BaseConfig):\n    @property\n    def mlp_class(self) -> Type:\n        return getattr(lit_gpt.adapter_v2, self._mlp_class)\n\n\ndef adapter_filter(key: str, value: Any) -> bool:\n    adapter_substrings = (\n        # regular adapter v1 parameters\n        \"adapter_wte\",\n        \"gating_factor\",\n        # adapter v2: new bias and scale used in Linear\n        \"adapter_scale\",\n        \"adapter_bias\",\n        # adapter v2: Norm parameters are now trainable\n        \"norm_1\",\n        \"norm_2\",\n        \"ln_f\",\n    )\n    return any(s in key for s in adapter_substrings)\n\n\nclass AdapterV2Linear(torch.nn.Module):\n    def __init__(self, in_features: int, out_features: int, **kwargs) -> None:\n        super().__init__()\n        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)\n        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)\n        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)\n        self.reset_parameters()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.adapter_scale * (self.linear(x) + self.adapter_bias)\n\n    def reset_parameters(self) -> None:\n        nn.init.zeros_(self.adapter_bias)\n        nn.init.ones_(self.adapter_scale)\n\n\nclass GPT(BaseModel):\n    def __init__(self, config: Config) -> None:\n        # Skip the parent class __init__ altogether and replace it to avoid useless allocations\n        nn.Module.__init__(self)\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False)\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n\n        self.rope_cache: Optional[RoPECache] = None\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.kv_caches: 
List[KVCache] = []\n        self.adapter_kv_caches: List[KVCache] = []\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.\"\"\"\n        super()._init_weights(module)\n        if isinstance(module, CausalSelfAttention):\n            module.reset_parameters()\n        if isinstance(module, AdapterV2Linear):\n            module.reset_parameters()\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\"lm_head.weight\": \"lm_head.linear.weight\"}\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass Block(BaseBlock):\n    \"\"\"The implementation is identical to `lit_gpt.model.Block` with the exception that\n    we replace the attention layer where adaption is implemented.\"\"\"\n\n    def __init__(self, config: Config, block_idx: int) -> None:\n        # Skip the parent class __init__ altogether and replace it to avoid useless allocations\n        nn.Module.__init__(self)\n        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.attn = CausalSelfAttention(config, block_idx)\n        if not config.shared_attention_norm:\n            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.mlp = config.mlp_class(config)\n\n        self.config = config\n\n\nclass CausalSelfAttention(BaseCausalSelfAttention):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        \"\"\"Causal self-attention that computes the q, k and v projections with a single matrix* and adds the\n        adapter v2 bias and scale parameters for parameter-efficient fine-tuning.\n\n        *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for\n        query, key and value for each head) we can do this in a single pass with a single weight matrix.\n        \"\"\"\n        # Skip the parent class __init__ altogether and replace it to avoid useless allocations\n        nn.Module.__init__(self)\n        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size\n        # key, query, value projections for all heads, but in a batch\n        self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)\n        # output projection\n        self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)\n        if block_idx >= config.adapter_start_layer:\n            # adapter embedding layer\n            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)\n            # gate for adaption\n            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))\n            self.reset_parameters()\n        self.block_idx = block_idx\n\n        self.config = config\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        rope: RoPECache,\n        max_seq_length: int,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        kv_cache: Optional[KVCache] = None,\n        adapter_kv_cache: Optional[KVCache] = None,\n    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:\n        B, T, C = 
x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n\n        qkv = self.attn(x)\n\n        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)\n        q_per_kv = self.config.n_head // self.config.n_query_groups\n        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value\n        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)\n        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)\n\n        # split batched computation into three\n        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)\n\n        # repeat k and v if necessary\n        if self.config.n_query_groups != 1:  # doing this would require a full kv cache with MQA (inefficient!)\n            # for MHA this is a no-op\n            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)\n            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)\n\n        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)\n        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)\n        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)\n\n        n_elem = int(self.config.rotary_percentage * self.config.head_size)\n\n        cos, sin = rope\n        q_roped = apply_rope(q[..., :n_elem], cos, sin)\n        k_roped = apply_rope(k[..., :n_elem], cos, sin)\n        q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)\n        k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)\n\n        if kv_cache is not None:\n            cache_k, cache_v = kv_cache\n            cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)\n            # check if reached token limit\n            if input_pos[-1] >= max_seq_length:\n                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)\n                # shift 1 position to the left\n                cache_k = torch.roll(cache_k, -1, dims=2)\n                cache_v = torch.roll(cache_v, -1, dims=2)\n            k = cache_k.index_copy_(2, input_pos, k)\n            v = cache_v.index_copy_(2, input_pos, v)\n            kv_cache = k, v\n\n        y = self.scaled_dot_product_attention(q, k, v, mask=mask)\n\n        if self.block_idx >= self.config.adapter_start_layer:\n            aT = self.config.adapter_prompt_length\n            if adapter_kv_cache is not None:\n                ak, av = adapter_kv_cache\n            else:\n                prefix = self.adapter_wte.weight.reshape(1, aT, C)\n                aqkv = self.attn(prefix)\n                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)\n                aqkv = aqkv.permute(0, 2, 3, 1, 4)\n                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)\n                if self.config.n_query_groups != 1:\n                    # for MHA this is a no-op\n                    ak = ak.repeat_interleave(q_per_kv, dim=2)\n                    av = av.repeat_interleave(q_per_kv, dim=2)\n                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)\n                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)\n                adapter_kv_cache = (ak, av)\n\n            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)\n            ay = self.scaled_dot_product_attention(q, ak, av, amask)\n            y = y + self.gating_factor * ay\n\n        y = y.reshape(B, T, C)  # 
re-assemble all head outputs side by side\n\n        # output projection\n        y = self.proj(y)\n\n        return y, kv_cache, adapter_kv_cache\n\n    def reset_parameters(self) -> None:\n        torch.nn.init.zeros_(self.gating_factor)\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"attn.weight\": \"attn.linear.weight\",\n            \"attn.bias\": \"attn.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        # For compatibility with older checkpoints\n        if (key := prefix + \"gating_factor\") in state_dict and state_dict[key].size(1) == self.config.n_head:\n            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass GptNeoxMLP(lit_gpt.model.GptNeoxMLP):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc.weight\": \"fc.linear.weight\",\n            \"fc.bias\": \"fc.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass LLaMAMLP(lit_gpt.model.LLaMAMLP):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc_1.weight\": \"fc_1.linear.weight\",\n            \"fc_1.bias\": \"fc_1.linear.bias\",\n            \"fc_2.weight\": \"fc_2.linear.weight\",\n            \"fc_2.bias\": \"fc_2.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef mark_only_adapter_v2_as_trainable(model: GPT) -> None:\n    \"\"\"Sets requires_grad=False for all non-adapter weights\"\"\"\n    for name, param in model.named_parameters():\n        param.requires_grad = adapter_filter(name, param)\n"
  },
  {
    "path": "lit_gpt/config.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom dataclasses import dataclass\nfrom typing import Any, Literal, Optional, Type\n\nimport torch\nfrom typing_extensions import Self\n\nimport lit_gpt.model\nfrom lit_gpt.utils import find_multiple\n\n\n@dataclass\nclass Config:\n    org: str = \"Lightning-AI\"\n    name: str = \"lit-GPT\"\n    block_size: int = 4096\n    vocab_size: int = 50254\n    padding_multiple: int = 512\n    padded_vocab_size: Optional[int] = None\n    n_layer: int = 16\n    n_head: int = 32\n    n_embd: int = 4096\n    rotary_percentage: float = 0.25\n    parallel_residual: bool = True\n    bias: bool = True\n    # to use multi-head attention (MHA), set this to `n_head` (default)\n    # to use multi-query attention (MQA), set this to 1\n    # to use grouped-query attention (GQA), set this to a value in between\n    # Example with `n_head=4`\n    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐\n    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │\n    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘\n    #   │    │    │    │         │        │                 │\n    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐\n    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │\n    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘\n    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐\n    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐\n    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │\n    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘\n    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶\n    #         MHA                    GQA                   MQA\n    #   n_query_groups=4       n_query_groups=2      n_query_groups=1\n    #\n    # credit https://arxiv.org/pdf/2305.13245.pdf\n    n_query_groups: Optional[int] = None\n    shared_attention_norm: bool = False\n    _norm_class: Literal[\"LayerNorm\", \"RMSNorm\"] = \"LayerNorm\"\n    norm_eps: float = 1e-5\n    _mlp_class: Literal[\"GptNeoxMLP\", \"LLaMAMLP\"] = \"GptNeoxMLP\"\n    intermediate_size: Optional[int] = None\n    condense_ratio: int = 1\n    \n    #flash attention config\n    enable_flash_attn: bool=False\n    flash_attn_dtype: torch.dtype=torch.bfloat16\n\n    def __post_init__(self):\n        # error checking\n        assert self.n_embd % self.n_head == 0\n        # vocab size should be a power of 2 to be optimal on hardware. 
compute the closest value\n        if self.padded_vocab_size is None:\n            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)\n        # compute the number of query groups\n        if self.n_query_groups is not None:\n            assert self.n_head % self.n_query_groups == 0\n        else:\n            self.n_query_groups = self.n_head\n        # compute the intermediate size for MLP if not set\n        if self.intermediate_size is None:\n            if self._mlp_class == \"LLaMAMLP\":\n                raise ValueError(\"The config needs to set the `intermediate_size`\")\n            self.intermediate_size = 4 * self.n_embd\n\n    @property\n    def head_size(self) -> int:\n        return self.n_embd // self.n_head\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        conf_dict = name_to_config[name].copy()\n        conf_dict.update(kwargs)\n        return cls(**conf_dict)\n\n    @property\n    def mlp_class(self) -> Type:\n        # `self._mlp_class` cannot be the type to keep the config json serializable\n        return getattr(lit_gpt.model, self._mlp_class)\n\n    @property\n    def norm_class(self) -> Type:\n        # `self._norm_class` cannot be the type to keep the config json serializable\n        if self._norm_class == \"RMSNorm\":\n            from lit_gpt.rmsnorm import RMSNorm\n\n            return RMSNorm\n        elif self._norm_class == \"FusedRMSNorm\":\n            from lit_gpt.rmsnorm import FusedRMSNorm\n            return FusedRMSNorm\n        return getattr(torch.nn, self._norm_class)\n\n\n########################\n# Stability AI StableLM\n########################\nconfigs = [\n    # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json\n    dict(org=\"stabilityai\", name=\"stablelm-base-alpha-3b\", padding_multiple=512),\n    # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json\n    dict(org=\"stabilityai\", name=\"stablelm-base-alpha-7b\", n_head=48, n_embd=6144, padding_multiple=256),\n    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json\n    dict(org=\"stabilityai\", name=\"stablelm-tuned-alpha-3b\", n_head=32, padding_multiple=512),\n    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json\n    dict(org=\"stabilityai\", name=\"stablelm-tuned-alpha-7b\", n_head=48, n_embd=6144, padding_multiple=256),\n]\n\n####################\n# EleutherAI Pythia\n####################\npythia = [\n    # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json\n    dict(org=\"EleutherAI\", name=\"pythia-70m\", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128),\n    # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json\n    dict(\n        org=\"EleutherAI\", name=\"pythia-160m\", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128\n    ),\n    # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json\n    dict(\n        org=\"EleutherAI\", name=\"pythia-410m\", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=12\n    ),\n    # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json\n    dict(org=\"EleutherAI\", name=\"pythia-1b\", block_size=2048, n_layer=16, n_embd=2048, n_head=8, padding_multiple=128),\n    # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json\n    dict(\n        org=\"EleutherAI\", name=\"pythia-1.4b\", block_size=2048, n_layer=24, n_embd=2048, 
n_head=16, padding_multiple=128\n    ),\n    # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json\n    dict(\n        org=\"EleutherAI\", name=\"pythia-2.8b\", block_size=2048, n_layer=32, n_embd=2560, n_head=32, padding_multiple=128\n    ),\n    # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json\n    dict(\n        org=\"EleutherAI\", name=\"pythia-6.9b\", block_size=2048, n_layer=32, n_embd=4096, n_head=32, padding_multiple=256\n    ),\n    # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json\n    dict(\n        org=\"EleutherAI\", name=\"pythia-12b\", block_size=2048, n_layer=36, n_embd=5120, n_head=40, padding_multiple=512\n    ),\n]\nconfigs.extend(pythia)\nfor c in pythia:\n    copy = c.copy()\n    copy[\"name\"] = f\"{c['name']}-deduped\"\n    configs.append(copy)\n\n\n####################################\n# togethercomputer RedPajama INCITE\n####################################\nredpajama_incite = [\n    # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json\n    dict(\n        org=\"togethercomputer\",\n        name=\"RedPajama-INCITE-{}-3B-v1\",\n        block_size=2048,\n        n_layer=32,\n        n_embd=2560,\n        n_head=32,\n        padding_multiple=256,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n    ),\n    # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json\n    dict(\n        org=\"togethercomputer\",\n        name=\"RedPajama-INCITE-7B-{}\",\n        block_size=2048,\n        n_layer=32,\n        n_embd=4096,\n        n_head=32,\n        padding_multiple=256,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n    ),\n    # this redirects to the checkpoint above. kept for those who had the old weights already downloaded\n    dict(\n        org=\"togethercomputer\",\n        name=\"RedPajama-INCITE-{}-7B-v0.1\",\n        block_size=2048,\n        n_layer=32,\n        n_embd=4096,\n        n_head=32,\n        padding_multiple=256,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n    ),\n]\nfor c in redpajama_incite:\n    for kind in (\"Base\", \"Chat\", \"Instruct\"):\n        copy = c.copy()\n        copy[\"name\"] = c[\"name\"].format(kind)\n        configs.append(copy)\n\n\n#################\n# TII UAE Falcon\n#################\nfalcon = [\n    # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json\n    dict(\n        org=\"tiiuae\",\n        name=\"falcon-7b{}\",\n        block_size=2048,\n        padded_vocab_size=65024,\n        n_layer=32,\n        n_head=71,\n        n_embd=4544,\n        rotary_percentage=1.0,\n        parallel_residual=True,\n        n_query_groups=1,\n        bias=False,\n        # this is not in the config, but in the original model implementation, only for this config\n        shared_attention_norm=True,\n    ),\n    # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json\n    dict(\n        org=\"tiiuae\",\n        name=\"falcon-40b{}\",\n        block_size=2048,\n        padded_vocab_size=65024,\n        n_layer=60,\n        n_head=128,\n        n_embd=8192,\n        rotary_percentage=1.0,\n        parallel_residual=True,\n        n_query_groups=8,\n        bias=False,\n    ),\n]\nfor c in falcon:\n    for kind in (\"\", \"-instruct\"):\n        copy = c.copy()\n        copy[\"name\"] = c[\"name\"].format(kind)\n        configs.append(copy)\n\n\n#############################\n# StatNLP Research\n#############################\ntiny_LLaMA = 
[\n    # https://twitter.com/cwolferesearch/status/1691929174175264858\n    dict(\n        org=\"StatNLP-research\",\n        name=\"original_tiny_LLaMA_1b\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=22,\n        n_head=32,\n        n_embd=2048,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=5632,\n        n_query_groups=4,\n    ),\n    dict(\n        org=\"AMD-research\",\n        name=\"tiny_LLaMA_135M\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=12,\n        n_head=12,\n        n_embd=768,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=2048,\n        n_query_groups=12,\n    ),\n    dict(\n        org=\"AMD-research\",\n        name=\"tiny_LLaMA_135M_2k\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=12,\n        n_head=12,\n        n_embd=768,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=2048,\n        n_query_groups=12,\n    ),\n    dict(\n        org=\"AMD-research\",\n        name=\"tiny_Qwen_315M_2k\",\n        block_size=2048,\n        vocab_size=152000,\n        padding_multiple=64,\n        n_layer=12,\n        n_head=12,\n        n_embd=768,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=2048,\n        n_query_groups=12,\n    ),\n    dict(\n        org=\"AMD-research\",\n        name=\"tiny_Qwen_315M_4k\",\n        block_size=4096,\n        vocab_size=152000,\n        padding_multiple=64,\n        n_layer=12,\n        n_head=12,\n        n_embd=768,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=2048,\n        n_query_groups=12,\n    ),\n    dict(\n        org=\"StatNLP-research\",\n        name=\"tiny_LLaMA_120M\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=12,\n        n_head=12,\n        n_embd=768,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=2048,\n        n_query_groups=1,\n    ),\n    dict(\n        org=\"StatNLP-research\",\n        name=\"code_tiny_LLaMA_1b\",\n        block_size=8192,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=22,\n        n_head=32,\n        n_embd=2048,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"FusedRMSNorm\",\n        norm_eps=1e-5,  # Llama 2 uses 1e-5; 
Llama 1 uses 1e-6\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=5632,\n        n_query_groups=4,\n        condense_ratio=4,\n    ),\n]\nconfigs.extend(tiny_LLaMA)\n\n\n#############################\n# OpenLM Research Open LLaMA\n#############################\nopen_LLaMA = [\n    # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json\n    dict(\n        org=\"openlm-research\",\n        name=\"open_llama_3b\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=26,\n        n_head=32,\n        n_embd=3200,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=8640,\n    ),\n    # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json\n    dict(\n        org=\"openlm-research\",\n        name=\"open_llama_7b\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json\n    dict(\n        org=\"openlm-research\",\n        name=\"open_llama_13b\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n]\nconfigs.extend(open_LLaMA)\n\n\n###############\n# LMSYS Vicuna\n###############\nvicuna = [\n    # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json\n    dict(\n        org=\"lmsys\",\n        name=\"vicuna-7b-v1.3\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json\n    dict(\n        org=\"lmsys\",\n        name=\"vicuna-13b-v1.3\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json\n    dict(\n        org=\"lmsys\",\n        name=\"vicuna-33b-v1.3\",\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=60,\n        n_head=52,\n        n_embd=6656,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=17920,\n    ),\n    dict(\n        org=\"lmsys\",\n        
name=\"vicuna-7b-v1.5\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    dict(\n        org=\"lmsys\",\n        name=\"vicuna-7b-v1.5-16k\",\n        block_size=16384,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n        condense_ratio=4,\n    ),\n    dict(\n        org=\"lmsys\",\n        name=\"vicuna-13b-v1.5\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    dict(\n        org=\"lmsys\",\n        name=\"vicuna-13b-v1.5-16k\",\n        block_size=16384,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=13824,\n        condense_ratio=4,\n    ),\n]\nconfigs.extend(vicuna)\n\n\n#################\n# LMSYS LongChat\n#################\nlong_chat = [\n    # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json\n    dict(\n        org=\"lmsys\",\n        name=\"longchat-7b-16k\",\n        block_size=16384,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n        condense_ratio=8,\n    ),\n    # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json\n    dict(\n        org=\"lmsys\",\n        name=\"longchat-13b-16k\",\n        block_size=16384,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=13824,\n        condense_ratio=8,\n    ),\n]\nconfigs.extend(long_chat)\n\n\n######################\n# NousResearch Hermes\n######################\nnous_research = [\n    # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json\n    dict(\n        org=\"NousResearch\",\n        name=\"Nous-Hermes-13b\",\n        block_size=2048,\n        padded_vocab_size=32001,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-6,\n        _mlp_class=\"LLaMAMLP\",\n        
intermediate_size=13824,\n    )\n]\nconfigs.extend(nous_research)\n\n\n###############\n# Meta LLaMA 2\n###############\nllama_2 = [\n    # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json\n    dict(\n        org=\"meta-llama\",\n        name=\"Llama-2-7b{}-hf\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    dict(\n        org=\"meta-llama\",\n        name=\"CodeLlama-2-7b-hf\",\n        block_size=4096,\n        vocab_size=32016,\n        padded_vocab_size=32016,\n        padding_multiple=64,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json\n    dict(\n        org=\"meta-llama\",\n        name=\"Llama-2-13b{}-hf\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json\n    dict(\n        org=\"meta-llama\",\n        name=\"Llama-2-70b{}-hf\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=28672,\n    ),\n]\nfor c in llama_2:\n    for kind in (\"\", \"-chat\"):\n        copy = c.copy()\n        copy[\"name\"] = c[\"name\"].format(kind)\n        configs.append(copy)\n\n\n##########################\n# Stability AI FreeWilly2\n##########################\nfreewilly_2 = [\n    # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json\n    dict(\n        org=\"stabilityai\",\n        name=\"FreeWilly2\",\n        block_size=4096,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        _norm_class=\"RMSNorm\",\n        norm_eps=1e-5,\n        _mlp_class=\"LLaMAMLP\",\n        intermediate_size=28672,\n    )\n]\nconfigs.extend(freewilly_2)\n\n\nname_to_config = {config[\"name\"]: config for config in configs}\n"
  },
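The `__post_init__` logic at the top of `lit_gpt/config.py` derives `padded_vocab_size`, `n_query_groups`, and `intermediate_size` whenever they are omitted, and the `mlp_class`/`norm_class` properties resolve the serializable string fields to actual classes. A minimal sketch of what that resolution yields for the `tiny_LLaMA_135M` entry above (assuming `lit_gpt` and its dependencies are importable):

```python
from lit_gpt.config import Config

config = Config.from_name("tiny_LLaMA_135M")
# vocab_size=32000 is already a multiple of padding_multiple=64, so no padding is added.
assert config.padded_vocab_size == 32000
# head_size is derived: n_embd // n_head = 768 // 12 = 64.
assert config.head_size == 64
# n_query_groups == n_head (12), i.e. plain multi-head attention rather than GQA/MQA.
assert config.n_query_groups == config.n_head == 12
# _mlp_class stays a string to keep the config JSON-serializable; the property
# resolves it to the class defined in lit_gpt.model.
assert config.mlp_class.__name__ == "LLaMAMLP"
```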
  {
    "path": "lit_gpt/fused_cross_entropy.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# Copyright (c) 2023, Tri Dao.\n\nimport torch\nimport torch.nn as nn\nimport xentropy_cuda_lib\n\n# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for\n# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent\n# version of PyTorch. The following 2 lines are for backward compatibility with\n# older PyTorch.\nif \"all_gather_into_tensor\" not in dir(torch.distributed):\n    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base\n\n\nclass SoftmaxCrossEntropyLossFn(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        logits,\n        labels,\n        smoothing=0.0,\n        ignored_index=-100,\n        inplace_backward=False,\n        process_group=None,\n    ):\n        \"\"\"\n        logits: (batch, vocab_size)\n        labels: (batch,)\n        If process_group is not None, we're doing Tensor Parallel: each process is responsible for\n        one part of the vocab. The loss needs to be aggregated across processes.\n        \"\"\"\n        batch, vocab_size = logits.shape\n        assert labels.shape == (batch,)\n        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)\n        ctx.total_classes = world_size * vocab_size\n\n        if world_size == 1:\n            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)\n            losses.masked_fill_(labels == ignored_index, 0)\n            labels_local = labels\n        else:\n            rank = torch.distributed.get_rank(process_group)\n            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size\n\n            # Create a mask of valid vocab ids (1 means it needs to be masked).\n            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)\n            ignored_mask = labels == ignored_index\n            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)\n\n            # For tensor parallel cross entropy with smoothing, we want to pass in the total number\n            # of classes so that smoothing can be applied correctly. 
If total_classes=-1, use the\n            # last dimension of the input tensor.\n            losses, lse_local = xentropy_cuda_lib.forward(\n                logits, labels_local, smoothing, world_size * vocab_size\n            )\n            assert lse_local.shape == (batch,)\n            assert losses.shape == (batch,)\n            losses.masked_fill_(ignored_mask, 0)\n            # For labels == ignored_index, the loss is always 0.\n            # If there's no smoothing, if labels are in the vocab of this partition, losses contains\n            # lse_local - predicted logit, and 0 otherwise.\n            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains\n            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)\n            # For labels not in the vocab of this partition, losses contains\n            # 0.1 * (lse_local - sum logit / total_classes).\n\n            lse_allgather = torch.empty(\n                world_size, batch, dtype=lse_local.dtype, device=lse_local.device\n            )\n            torch.distributed.all_gather_into_tensor(\n                lse_allgather, lse_local.contiguous(), group=process_group\n            )\n            handle_losses = torch.distributed.all_reduce(\n                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True\n            )\n            lse = torch.logsumexp(lse_allgather, dim=0)\n            # If there's no smoothing, the total losses are lse_local - predicted_logit,\n            # we just have to subtract the lse_local and add the lse (global).\n            # If there's smoothing=0.1, the total losses are\n            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)\n            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).\n            rank_per_sample = torch.div(labels, vocab_size, rounding_mode=\"floor\")\n            lse_local = lse_allgather[\n                rank_per_sample, torch.arange(batch, device=lse_allgather.device)\n            ]\n\n            handle_losses.wait()\n            if smoothing == 0.0:\n                losses += lse - lse_local\n            else:\n                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (\n                    lse - lse_allgather.sum(dim=0)\n                )\n            losses.masked_fill_(ignored_mask, 0)\n\n        ctx.save_for_backward(logits, lse, labels_local)\n        ctx.smoothing = smoothing\n        ctx.ignored_index = ignored_index\n        ctx.inplace_backward = inplace_backward\n        return losses\n\n    @staticmethod\n    def backward(ctx, grad_loss):\n        logits, lse, labels = ctx.saved_tensors\n        grad_loss = grad_loss.contiguous()\n        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)\n        grad_logits = xentropy_cuda_lib.backward(\n            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes\n        )\n        return grad_logits, None, None, None, None, None, None\n\n\nclass FusedCrossEntropyLoss(nn.Module):\n    def __init__(\n        self,\n        ignore_index=-100,\n        reduction=\"mean\",\n        label_smoothing=0.0,\n        inplace_backward=True,\n        process_group=None,\n    ):\n        super().__init__()\n        if reduction not in [\"mean\", \"none\"]:\n            raise NotImplementedError(\"Only support reduction = 'mean' or 'none'\")\n        self.ignore_index = ignore_index\n        
self.reduction = reduction\n        self.label_smoothing = label_smoothing\n        self.inplace_backward = inplace_backward\n        self.process_group = process_group\n\n    def forward(self, input, target):\n        assert input.is_cuda and target.is_cuda\n        # SoftmaxCrossEntropyLoss implicitly casts to float\n        if len(input.shape) == 3:\n            input = input.view(-1, input.size(-1))\n            target = target.view(-1)\n        loss = SoftmaxCrossEntropyLossFn.apply(\n            input,\n            target,\n            self.label_smoothing,\n            self.ignore_index,\n            self.inplace_backward,\n            self.process_group,\n        )\n        if self.reduction == \"mean\":\n            return loss.sum() / (target != self.ignore_index).sum()\n        else:\n            return loss"
  },
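For the single-process path (`process_group=None`), `FusedCrossEntropyLoss` computes ordinary label-smoothed cross entropy; the CUDA extension only fuses it. A rough unfused sketch of the quantity described in the comments above, useful for CPU sanity checks (this helper is illustrative, not part of the repo's sources):

```python
import torch

def unfused_cross_entropy(logits, labels, smoothing=0.0, ignore_index=-100):
    # Per the comments above: (1 - s) * (lse - predicted_logit) + s * (lse - sum_logits / total_classes)
    logits = logits.float()
    lse = torch.logsumexp(logits, dim=-1)  # (batch,)
    valid = labels != ignore_index
    picked = logits.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    losses = (1 - smoothing) * (lse - picked) + smoothing * (lse - logits.mean(dim=-1))
    losses = torch.where(valid, losses, torch.zeros_like(losses))
    # "mean" reduction averages over non-ignored targets only, as in the module above.
    return losses.sum() / valid.sum()

# This should agree with torch.nn.functional.cross_entropy(logits, labels,
# label_smoothing=smoothing, ignore_index=-100) up to floating-point noise.
```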
  {
    "path": "lit_gpt/fused_rotary_embedding.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# Copyright (c) 2023, Tri Dao.\n\nimport math\nfrom typing import Optional, Tuple\n\nimport rotary_emb\nimport torch\nfrom einops import rearrange, repeat\n\nclass ApplyRotaryEmb(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):\n        \"\"\"\n            x: (batch_size, seqlen, nheads, headdim)\n            cos, sin: (seqlen, rotary_dim / 2)\n            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead\n                of 1st half and 2nd half (GPT-NeoX style).\n        rotary_dim must be <= headdim\n        Apply rotary embedding to the first rotary_dim of x.\n        \"\"\"\n        batch, seqlen, nheads, headdim = x.shape\n        rotary_seqlen, rotary_dim = cos.shape\n        rotary_dim *= 2\n        assert rotary_dim <= headdim\n        assert seqlen <= rotary_seqlen\n        assert sin.shape == (rotary_seqlen, rotary_dim // 2)\n        x_ro = x[..., :rotary_dim]\n        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])\n        out = torch.empty_like(x) if not inplace else x\n        out_ro = out[..., :rotary_dim]\n        if inplace:\n            o1, o2 = x1, x2\n        else:\n            o1, o2 = (\n                out_ro.chunk(2, dim=-1)\n                if not interleaved\n                else (out_ro[..., ::2], out_ro[..., 1::2])\n            )\n        rotary_emb.apply_rotary(\n            x1,\n            x2,\n            rearrange(cos[:seqlen], \"s d -> s 1 d\"),\n            rearrange(sin[:seqlen], \"s d -> s 1 d\"),\n            o1,\n            o2,\n            False,\n        )\n        if not inplace and rotary_dim < headdim:\n            out[..., rotary_dim:].copy_(x[..., rotary_dim:])\n        ctx.save_for_backward(cos, sin)\n        ctx.interleaved = interleaved\n        ctx.inplace = inplace\n        return out if not inplace else x\n\n    @staticmethod\n    def backward(ctx, do):\n        cos, sin = ctx.saved_tensors\n        _, seqlen, _, headdim = do.shape\n        rotary_dim = cos.shape[-1]\n        rotary_dim *= 2\n        inplace = ctx.inplace\n        do_ro = do[..., :rotary_dim]\n        do1, do2 = (\n            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])\n        )\n        dx = torch.empty_like(do) if not inplace else do\n        if inplace:\n            dx1, dx2 = do1, do2\n        else:\n            dx_ro = dx[..., :rotary_dim]\n            dx1, dx2 = (\n                dx_ro.chunk(2, dim=-1)\n                if not ctx.interleaved\n                else (dx_ro[..., ::2], dx_ro[..., 1::2])\n            )\n        rotary_emb.apply_rotary(\n            do1,\n            do2,\n            rearrange(cos[:seqlen], \"s d -> s 1 d\"),\n            rearrange(sin[:seqlen], \"s d -> s 1 d\"),\n            dx1,\n            dx2,\n       
     True,\n        )\n        if not inplace and rotary_dim < headdim:\n            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])\n        return dx, None, None, None, None\n\n\napply_rotary_emb_func = ApplyRotaryEmb.apply\n\n"
  },
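`ApplyRotaryEmb` dispatches to the fused `rotary_emb.apply_rotary` kernel; in the default non-interleaved (GPT-NeoX style) path it rotates the first `rotary_dim` features of each head and passes the rest through unchanged. A pure-PyTorch sketch of the same transform, for checking shapes and semantics only:

```python
import torch

def rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, nheads, headdim); cos, sin: (rotary_seqlen, rotary_dim / 2)
    rotary_dim = 2 * cos.shape[-1]
    x_ro, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_ro.chunk(2, dim=-1)  # 1st half / 2nd half split (GPT-NeoX style)
    seqlen = x.shape[1]
    c = cos[:seqlen].unsqueeze(1)  # (seqlen, 1, rotary_dim / 2), broadcasts over heads
    s = sin[:seqlen].unsqueeze(1)
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```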
  {
    "path": "lit_gpt/lora.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# Derived from https://github.com/microsoft/LoRA\n#  ------------------------------------------------------------------------------------------\n#  Copyright (c) Microsoft Corporation. All rights reserved.\n#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.\n#  ------------------------------------------------------------------------------------------\n\nr\"\"\"\n    Low Ranking Adaptation for LLMs scheme.\n\n             ┌───────────────────┐\n             ┆         h         ┆\n             └───────────────────┘\n                       ▲\n                       |\n                       +\n                    /     \\\n    ┌─────────────────┐    ╭───────────────╮     Matrix initialization:\n    ┆                 ┆     \\      B      /      B = 0\n    ┆   pretrained    ┆      \\    r*d    /       A = N(0, sigma^2)\n    ┆    weights      ┆       ╰─────────╯\n    ┆                 ┆       |    r    |        r - rank\n    ┆   W e R^(d*d)   ┆       | ◀─────▶ |\n    ┆                 ┆       ╭─────────╮\n    └─────────────────┘      /     A     \\\n              ▲             /     d*r     \\\n               \\           ╰───────────────╯\n                \\                ▲\n                 \\              /\n                  \\            /\n             ┌───────────────────┐\n             ┆         x         ┆\n             └───────────────────┘\n\nWith LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,\nwe can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates\nfor the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of\ncourse) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen\npretrained weights and thus fine-tune the model.\n\nThe goal of this approach is to move weight updates into a separate matrix which is decomposed with\ntwo matrices of a lower rank.\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Type, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom typing_extensions import Self\n\nimport lit_gpt\nfrom lit_gpt.config import Config as BaseConfig\nfrom lit_gpt.model import GPT as BaseModel\nfrom lit_gpt.model import Block as BaseBlock\nfrom lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention\nfrom lit_gpt.model import KVCache, RoPECache\nfrom lit_gpt.utils import map_old_state_dict_weights\n\n\nclass LoRALayer(nn.Module):\n    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):\n        \"\"\"Store LoRA specific attributes in a class.\n\n        Args:\n            r: rank of the weight update matrices. 
To make sense of using LoRA the rank should be smaller than the rank of\n                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n            lora_alpha: alpha is needed for scaling updates as alpha/r\n                \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n        \"\"\"\n        super().__init__()\n        assert r >= 0\n        self.r = r\n        self.lora_alpha = lora_alpha\n        # Optional dropout\n        if lora_dropout > 0.0:\n            self.lora_dropout = nn.Dropout(p=lora_dropout)\n        else:\n            self.lora_dropout = lambda x: x\n        # Mark the weight as unmerged\n        self.merged = False\n\n\nclass LoRALinear(LoRALayer):\n    # LoRA implemented in a dense layer\n    def __init__(\n        self,\n        # ↓ this part is for pretrained weights\n        in_features: int,\n        out_features: int,\n        # ↓ the remaining part is for LoRA\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n        **kwargs,\n    ):\n        \"\"\"LoRA wrapper around linear class.\n\n        This class has three weight matrices:\n            1. Pretrained weights are stored as `self.linear.weight`\n            2. LoRA A matrix as `self.lora_A`\n            3. LoRA B matrix as `self.lora_B`\n        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.\n\n        Args:\n            in_features: number of input features of the pretrained weights\n            out_features: number of output features of the pretrained weights\n            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of\n                the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n            lora_alpha: alpha is needed for scaling updates as alpha/r\n                \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n        \"\"\"\n        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)\n        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)\n\n        # Actual trainable parameters\n        if r > 0:\n            self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))\n            self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))\n            self.scaling = self.lora_alpha / self.r\n            self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reset all the weights, even including pretrained ones.\"\"\"\n        if hasattr(self, \"lora_A\"):\n            # initialize A the same way as the default for nn.Linear and B to zero\n            # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314\n            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n            nn.init.zeros_(self.lora_B)\n\n    def merge(self):\n        \"\"\"Merges the LoRA weights into the full-rank weights (W = W + delta_W).\"\"\"\n        if self.r > 0 and not self.merged:\n            # Merge the weights and mark it\n            self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling\n            self.merged = True\n\n    def forward(self, x: torch.Tensor):\n        # if the weights are merged or the rank is zero (LoRA is disabled), this is only a regular nn.Linear forward pass;\n        # otherwise additionally do the forward pass with the LoRA weights and add its output to the output from the pretrained weights\n        pretrained = self.linear(x)\n        if self.r == 0 or self.merged:\n            return pretrained\n        lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling\n        return pretrained + lora\n\n\nclass LoRAQKVLinear(LoRALinear):\n    # LoRA implemented in a dense layer\n    def __init__(\n        self,\n        # ↓ this part is for pretrained weights\n        in_features: int,\n        out_features: int,\n        # ↓ the remaining part is for LoRA\n        n_head: int,\n        n_query_groups: int,\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,\n        **kwargs,\n    ):\n        \"\"\"LoRA wrapper around linear class that is used for calculation of q, k and v matrices.\n\n        This class has three weight matrices:\n            1. Pretrained weights are stored as `self.linear.weight`\n            2. LoRA A matrix as `self.lora_A`\n            3. LoRA B matrix as `self.lora_B`\n        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.\n\n        Args:\n            in_features: number of input features of the pretrained weights\n            out_features: number of output features of the pretrained weights\n            n_head: number of attention heads\n            n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)\n            r: rank of the weight update matrices. 
To make sense of using LoRA the rank should be smaller than the rank of\n                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n            lora_alpha: alpha is needed for scaling updates as alpha/r\n                \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n            enable_lora: MergeLinear class is for the attention mechanism where qkv are calculated with a single weight matrix. If we\n                don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query`\n                and `value` but keep `key` without weight updates we should pass `[True, False, True]`\n        \"\"\"\n        super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)\n        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)\n        self.n_head = n_head\n        self.n_query_groups = n_query_groups\n        if isinstance(enable_lora, bool):\n            enable_lora = [enable_lora] * 3\n        assert len(enable_lora) == 3\n        self.enable_lora = enable_lora\n\n        # Actual trainable parameters\n        # To better understand initialization let's imagine that we have such parameters:\n        # ⚬ in_features: 128 (embeddings_size)\n        # ⚬ out_features: 384 (3 * embedding_size)\n        # ⚬ r: 2\n        # ⚬ enable_lora: [True, False, True]\n        if r > 0 and any(enable_lora):\n            self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r * sum(enable_lora), in_features)))  # (4, 128)\n            enable_q, enable_k, enable_v = enable_lora\n            self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)\n            # qkv_shapes will be used to split a tensor with weights correctly\n            qkv_shapes = (\n                self.linear.in_features * enable_q,\n                self.kv_embd_size * enable_k,\n                self.kv_embd_size * enable_v,\n            )\n            self.qkv_shapes = [s for s in qkv_shapes if s]\n            self.lora_B = nn.Parameter(self.linear.weight.new_zeros(sum(self.qkv_shapes), r))  # (256, 2)\n            # Notes about shapes above\n            # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;\n            # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in\n            # F.linear function weights are automatically transposed. In addition conv1d requires channels to\n            # be before seq length\n            # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is\n            # 128*2; 2 tells to have two channels per group for group convolution\n\n            # Scaling:\n            # This balances the pretrained model's knowledge and the new task-specific adaptation\n            # https://lightning.ai/pages/community/tutorial/lora-llm/\n            # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set\n            # alpha to a lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can\n            # tune these values to your needs. 
This value can be even slightly greater than 1.0!\n            # https://github.com/cloneofsimo/lora\n            self.scaling = self.lora_alpha / self.r\n\n            # Compute the indices\n            # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,\n            # but not keys, then the weights update should be:\n            #\n            # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],\n            #  [....................................],\n            #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]\n            #      ↑              ↑            ↑\n            # ________________________________________\n            # | query         | key       | value    |\n            # ----------------------------------------\n            self.lora_ind = []\n            if enable_q:\n                self.lora_ind.extend(range(0, self.linear.in_features))\n            if enable_k:\n                self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))\n            if enable_v:\n                self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))\n            self.reset_parameters()\n\n    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Properly pad weight updates with zeros.\n\n        If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,\n        then the weights update should be:\n\n        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],\n         [....................................],\n         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]\n            ↑              ↑            ↑\n        ________________________________________\n        | query         | key       | value    |\n        ----------------------------------------\n\n        Args:\n            x: tensor with weights update that will be padded with zeros if necessary\n\n        Returns:\n            A tensor with weight updates and zeros for deselected q, k or v\n        \"\"\"\n        # we need to do zero padding only if LoRA is disabled for one of the QKV matrices\n        if all(self.enable_lora):\n            return x\n\n        # Let's imagine that:\n        # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)\n        # ⚬ embeddings_size: 128\n        # ⚬ self.linear.out_features: 384 (3 * embeddings_size)\n        # ⚬ enable_lora: [True, False, True]\n        # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected\n        # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but\n        # only for key updates (this is where self.lora_ind comes in handy)\n        # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors\n        # for example when we want to merge/unmerge LoRA weights and pretrained weights\n        x = x.transpose(0, 1)\n        result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)\n        result = result.view(-1, self.linear.out_features)  # (4096, 384)\n        result = result.index_copy(\n            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))\n        )  # (4096, 384)\n        return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)\n\n    def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:\n        
\"\"\"An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.\n\n        If the number of heads is equal to the number of query groups - grouped queries are disabled\n        (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized\n        query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the\n        input and weight matrices will be splitted in equally sized parts and applied separately (like having multiple\n        conv layers side by side).\n\n        Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,\n        apply each part of the weight matrix to the corresponding input's part and concatenate the result.\n\n        Args:\n            input: input matrix of shape (B, C, T)\n            weight: weight matrix of shape (C_output, rank, 1).\n                \"C_output\" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).\n\n        Returns:\n            A tensor with a shape (B, C_output, T)\n\n        \"\"\"\n        if self.n_head == self.n_query_groups:\n            return F.conv1d(input, weight, groups=sum(self.enable_lora))  # (B, C_output, T)\n\n        # Notation:\n        # ⚬ N: number of enabled LoRA layers (self.enable_lora)\n        # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)\n        # ⚬ r: rank of all LoRA layers (equal in size)\n\n        input_splitted = input.chunk(sum(self.enable_lora), dim=1)  # N * (B, C // N, T)\n        weight_splitted = weight.split(self.qkv_shapes)  # N * (C_output', r, 1)\n        return torch.cat(\n            [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1  # (B, C_output', T)\n        )  # (B, C_output, T)\n\n    def merge(self):\n        \"\"\"Merges the LoRA weights into the full-rank weights (W = W + delta_W).\"\"\"\n\n        # Let's assume that:\n        # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)\n        # ⚬ self.lora_A.data: (4, 128)\n        # ⚬ self.lora_B.data: (256, 2)\n        if self.r > 0 and any(self.enable_lora) and not self.merged:\n            delta_w = self.conv1d(\n                self.lora_A.data.unsqueeze(0),  # (4, 128) -> (1, 4, 128)\n                self.lora_B.data.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)\n            ).squeeze(\n                0\n            )  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)\n            # W = W + delta_W (merge)\n            self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)\n            self.merged = True\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Do the forward pass.\n\n        If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.\n        If not, then multiply pretrained weights with input, apply LoRA on input and do summation.\n\n        Args:\n            x: input tensor of shape (batch_size, context_length, embedding_size)\n\n        Returns:\n            Output tensor of shape (batch_size, context_length, 3 * embedding_size)\n        \"\"\"\n\n        # Let's assume that:\n        # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)\n        # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)\n        # ⚬ self.lora_A.data: (4, 128)\n        # ⚬ 
self.lora_B.data: (256, 2)\n\n        # if the weights are merged or LoRA is disabled (r == 0 or all `enable_lora` are False), this is only a regular nn.Linear forward pass;\n        # otherwise additionally do the forward pass with the LoRA weights and add its output to the output from the pretrained weights\n        pretrained = self.linear(x)\n        if self.r == 0 or not any(self.enable_lora) or self.merged:\n            return pretrained\n        after_A = F.linear(self.lora_dropout(x), self.lora_A)  # (64, 64, 128) @ (4, 128) -> (64, 64, 4)\n        # For F.conv1d:\n        # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)\n        # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)\n        after_B = self.conv1d(\n            after_A.transpose(-2, -1),  # (64, 64, 4) -> (64, 4, 64)\n            self.lora_B.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)\n        ).transpose(\n            -2, -1\n        )  # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)\n        lora = self.zero_pad(after_B) * self.scaling  # (64, 64, 256) after zero_pad (64, 64, 384)\n        return pretrained + lora\n\n\ndef mark_only_lora_as_trainable(model: nn.Module, bias: str = \"none\") -> None:\n    \"\"\"Freeze all modules except LoRA's and, depending on the 'bias' value, unfreeze bias weights.\n\n    Args:\n        model: model with LoRA layers\n        bias:\n            ``\"none\"``: all bias weights will be frozen,\n            ``\"lora_only\"``: only bias weight for LoRA layers will be unfrozen,\n            ``\"all\"``: all bias weights will be unfrozen.\n\n    Raises:\n        NotImplementedError: if `bias` not in [\"none\", \"lora_only\", \"all\"]\n    \"\"\"\n    # freeze all layers except LoRA's\n    for n, p in model.named_parameters():\n        if \"lora_\" not in n:\n            p.requires_grad = False\n\n    # depending on the `bias` value unfreeze bias weights\n    if bias == \"none\":\n        return\n    if bias == \"all\":\n        for n, p in model.named_parameters():\n            if \"bias\" in n:\n                p.requires_grad = True\n    elif bias == \"lora_only\":\n        for m in model.modules():\n            if isinstance(m, LoRALayer) and hasattr(m, \"bias\") and m.bias is not None:\n                m.bias.requires_grad = True\n    else:\n        raise NotImplementedError\n\n\ndef lora_filter(key: str, value: Any) -> bool:\n    return \"lora_\" in key\n\n\n@dataclass\nclass Config(BaseConfig):\n    \"\"\"\n    Args:\n        r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of\n            the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n        alpha: alpha is needed for scaling updates as alpha/r\n            \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n            https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n        dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n        to_*: either apply LoRA to the specified weights or not\n    \"\"\"\n\n    r: int = 0\n    alpha: int = 1\n    dropout: float = 0.0\n    to_query: bool = False\n    to_key: bool = False\n    to_value: bool = False\n    to_projection: bool = False\n    to_mlp: bool = False\n    to_head: bool = False\n\n    @property\n    def mlp_class(self) -> Type:\n        return getattr(lit_gpt.lora, self._mlp_class)\n\n\nclass GPT(BaseModel):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = LoRALinear(\n            config.n_embd,\n            config.padded_vocab_size,\n            bias=False,\n            r=(config.r if config.to_head else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n\n        self.rope_cache: Optional[RoPECache] = None\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.kv_caches: List[KVCache] = []\n\n    def forward(\n        self,\n        idx: torch.Tensor,\n        max_seq_length: Optional[int] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        lm_head_chunk_size: int = 0,\n    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n        B, T = idx.size()\n        use_kv_cache = input_pos is not None\n\n        block_size = self.config.block_size\n        if max_seq_length is None:\n            max_seq_length = block_size\n        if use_kv_cache:  # not relevant otherwise\n            assert (\n                max_seq_length >= T\n            ), f\"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}\"\n        assert max_seq_length <= block_size, f\"Cannot attend to {max_seq_length}, block size is only {block_size}\"\n        assert block_size >= T, f\"Cannot forward sequence of length {T}, block size is only {block_size}\"\n\n        if self.rope_cache is None:\n            self.rope_cache = self.build_rope_cache(idx)  # 2 * (block_size, head_size * rotary_percentage)\n        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. 
since we only need the mask\n        # for the kv-cache support (only during inference), we only create it in that situation\n        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099\n        if use_kv_cache and self.mask_cache is None:\n            self.mask_cache = self.build_mask_cache(idx)  # (1, 1, block_size, block_size)\n\n        cos, sin = self.rope_cache\n        if use_kv_cache:\n            cos = cos.index_select(0, input_pos)\n            sin = sin.index_select(0, input_pos)\n            mask = self.mask_cache.index_select(2, input_pos)\n            mask = mask[:, :, :, :max_seq_length]\n        else:\n            cos = cos[:T]\n            sin = sin[:T]\n            mask = None\n\n        # forward the model itself\n        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)\n\n        if not use_kv_cache:\n            for block in self.transformer.h:\n                x, *_ = block(x, (cos, sin), max_seq_length)\n        else:\n            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1))\n            for i, block in enumerate(self.transformer.h):\n                x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])\n\n        x = self.transformer.ln_f(x)\n\n        if lm_head_chunk_size > 0:\n            # chunk the lm head logits to reduce the peak memory used by autograd\n            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]\n        return self.lm_head(x)  # (B, T, vocab_size)\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.\"\"\"\n        super()._init_weights(module)\n        if isinstance(module, LoRALinear):\n            module.reset_parameters()\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\"lm_head.weight\": \"lm_head.linear.weight\"}\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass Block(BaseBlock):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.attn = CausalSelfAttention(config)\n        if not config.shared_attention_norm:\n            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.mlp = config.mlp_class(config)\n\n        self.config = config\n\n\nclass CausalSelfAttention(BaseCausalSelfAttention):\n    def __init__(self, config: Config) -> None:\n        \"\"\"Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for\n        parameter-efficient fine-tuning.\n\n        *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for\n        query, key and value for each head) we can do this in a single pass with a single weight matrix.\n        \"\"\"\n        # Skip the parent class __init__ altogether and replace it to avoid\n        # useless allocations\n        nn.Module.__init__(self)\n        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size\n        # key, query, value projections for all heads, but in a batch\n        self.attn = LoRAQKVLinear(\n            in_features=config.n_embd,\n            out_features=shape,\n            r=config.r,\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n            enable_lora=(config.to_query, config.to_key, config.to_value),\n            bias=config.bias,\n            # for MQA/GQA support\n            n_head=config.n_head,\n            n_query_groups=config.n_query_groups,\n        )\n        # output projection\n        self.proj = LoRALinear(\n            config.n_embd,\n            config.n_embd,\n            bias=config.bias,\n            r=(config.r if config.to_projection else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"attn.weight\": \"attn.linear.weight\",\n            \"attn.bias\": \"attn.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass GptNeoxMLP(lit_gpt.model.GptNeoxMLP):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.fc = LoRALinear(\n            config.n_embd,\n            config.intermediate_size,\n            bias=config.bias,\n            r=(config.r if config.to_mlp else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n    
    self.proj = LoRALinear(\n            config.intermediate_size,\n            config.n_embd,\n            bias=config.bias,\n            r=(config.r if config.to_mlp else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc.weight\": \"fc.linear.weight\",\n            \"fc.bias\": \"fc.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass LLaMAMLP(lit_gpt.model.LLaMAMLP):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.fc_1 = LoRALinear(\n            config.n_embd,\n            config.intermediate_size,\n            bias=config.bias,\n            r=(config.r if config.to_mlp else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n        self.fc_2 = LoRALinear(\n            config.n_embd,\n            config.intermediate_size,\n            bias=config.bias,\n            r=(config.r if config.to_mlp else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n        self.proj = LoRALinear(\n            config.intermediate_size,\n            config.n_embd,\n            bias=config.bias,\n            r=(config.r if config.to_mlp else 0),\n            lora_alpha=config.alpha,\n            lora_dropout=config.dropout,\n        )\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc_1.weight\": \"fc_1.linear.weight\",\n            \"fc_1.bias\": \"fc_1.linear.bias\",\n            \"fc_2.weight\": \"fc_2.linear.weight\",\n            \"fc_2.bias\": \"fc_2.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef merge_lora_weights(model: GPT) -> None:\n    \"\"\"Merge LoRA weights into the full-rank weights to speed up inference.\"\"\"\n    for module in model.modules():\n        if isinstance(module, LoRALinear):\n            module.merge()\n"
  },
  {
    "path": "lit_gpt/model.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n\"\"\"Full definition of a GPT NeoX Language Model, all of it in this single file.\n\nBased on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and\nhttps://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.\n\"\"\"\nimport math\nfrom typing import Any, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n# from lightning_utilities.core.imports import RequirementCache\nfrom typing_extensions import Self\n# from flash_attn import flash_attn_func\nfrom lit_gpt.config import Config\n#from xformers.ops import SwiGLU\n#from .fused_rotary_embedding import apply_rotary_emb_func\nfrom .rotary_ebm import apply_rotary_pos_emb\n\nRoPECache = Tuple[torch.Tensor, torch.Tensor]\nKVCache = Tuple[torch.Tensor, torch.Tensor]\n# FlashAttention2Available = RequirementCache(\"flash-attn>=2.0.0.post1\")\n\n# input_pos_global = torch.arange(0, 4096, device=torch.device('cuda'))\n#import triton\n#from triton import ops\n\nclass GPT(nn.Module):\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n        self.rope_cache: Optional[RoPECache] = None\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.kv_caches: List[KVCache] = []\n\n    def _init_weights(self, module: nn.Module, n_layer) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`.\"\"\"\n        # GPT-NeoX  https://arxiv.org/pdf/2204.06745.pdf\n        if isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))\n            # RWKV: set it to 1e-4\n            # torch.nn.init.uniform_(module.weight,  -1e-4, 1e-4)\n        elif isinstance(module, nn.Linear):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        # GPT-NeoX       \n        for name, p in module.named_parameters():\n            if (name == \"proj.weight\" and isinstance(module, LLaMAMLP)) or (name == \"w3.weight\" and isinstance(module, SwiGLU) or (name==\"proj.weight\" and isinstance(module, CausalSelfAttention))):  #if use xformer swiglu, fc2 layer will be renamed to w3\n                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd)  /  n_layer)\n        \n\n    def reset_cache(self) -> None:\n        
self.kv_caches.clear()\n        if self.mask_cache is not None and self.mask_cache.device.type == \"xla\":\n            # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179\n            self.rope_cache = None\n            self.mask_cache = None\n\n    def forward(\n        self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        B, T = idx.size()\n        use_kv_cache = input_pos is not None\n\n        block_size = self.config.block_size\n        if max_seq_length is None:\n            max_seq_length = block_size\n        if use_kv_cache:  # not relevant otherwise\n            assert (\n                max_seq_length >= T\n            ), f\"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}\"\n        assert max_seq_length <= block_size, f\"Cannot attend to {max_seq_length}, block size is only {block_size}\"\n        assert block_size >= T, f\"Cannot forward sequence of length {T}, block size is only {block_size}\"\n\n        if self.rope_cache is None:\n            self.rope_cache = self.build_rope_cache(idx)\n        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask\n        # for the kv-cache support (only during inference), we only create it in that situation\n        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099\n        if use_kv_cache and self.mask_cache is None:\n            self.mask_cache = self.build_mask_cache(idx)\n\n        cos, sin = self.rope_cache\n        if use_kv_cache:\n\n            cos = cos.index_select(0, input_pos)\n            sin = sin.index_select(0, input_pos)\n            mask = self.mask_cache.index_select(2, input_pos)\n            mask = mask[:, :, :, :max_seq_length]\n        else:\n            cos = cos[:T]\n            sin = sin[:T]\n            mask = None\n            input_pos = torch.arange(0, T, device=idx.device)\n\n        # forward the model itself\n        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)\n            \n        if not use_kv_cache:\n            for block in self.transformer.h:\n                x, *_ = block(x, (cos, sin), max_seq_length, input_pos=input_pos)\n        else:\n            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)\n            for i, block in enumerate(self.transformer.h):\n                x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])\n\n        x = self.transformer.ln_f(x)\n        return self.lm_head(x)  # (b, t, vocab_size)\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:\n        return build_rope_cache(\n            seq_len=self.config.block_size,\n            n_elem=int(self.config.rotary_percentage * self.config.head_size),\n            dtype=torch.bfloat16,\n            device=idx.device,\n            condense_ratio=self.config.condense_ratio,\n        )\n\n    def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:\n        ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)\n        return torch.tril(ones).unsqueeze(0).unsqueeze(0)\n\n    def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:\n        B = 
idx.size(0)\n        heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups\n\n        k_cache_shape = (\n            B,\n            max_seq_length,\n            heads,\n            rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),\n        )\n        v_cache_shape = (B, max_seq_length, heads, self.config.head_size)\n        device = idx.device\n        return [\n            (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))\n            for _ in range(self.config.n_layer)\n        ]\n\n\nclass Block(nn.Module):\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.attn = CausalSelfAttention(config)\n        if not config.shared_attention_norm:\n            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.mlp = config.mlp_class(config)\n        self.config = config\n    def forward(\n        self,\n        x: torch.Tensor,\n        rope: RoPECache,\n        max_seq_length: int,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        kv_cache: Optional[KVCache] = None,\n    ) -> Tuple[torch.Tensor, Optional[KVCache]]:\n\n        n_1 = self.norm_1(x)\n        h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)\n        if self.config.parallel_residual:\n            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)\n            x = x + h + self.mlp(n_2)\n        else:\n            if self.config.shared_attention_norm:\n                raise NotImplementedError(\n                    \"No checkpoint amongst the ones we support uses this configuration\"\n                    \" (non-parallel residual and shared attention norm).\"\n                )\n            \n            x = x + h\n            x = x + self.mlp(self.norm_2(x))\n        return x, new_kv_cache\n\n\nclass CausalSelfAttention(nn.Module):\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size\n        # key, query, value projections for all heads, but in a batch\n        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)\n        # output projection\n        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)\n\n        self.config = config\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        rope: RoPECache,\n        max_seq_length: int,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        kv_cache: Optional[KVCache] = None,\n    ) -> Tuple[torch.Tensor, Optional[KVCache]]:\n        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n\n        qkv = self.attn(x)\n\n        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)\n        q_per_kv = self.config.n_head // self.config.n_query_groups\n        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value\n        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)\n        # qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)\n\n        # split batched computation into three\n        q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)\n\n    
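    # Illustrative shapes: with n_head=12, n_query_groups=4, head_size=64, total_qkv is 5, so qkv is (B, T, 4, 5, 64) and splits into q of (B, T, 4, 3, 64) plus k and v of (B, T, 4, 1, 64) each.\n    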
    # repeat k and v if necessary\n        # Peiyuan: we do not need to do this as flash attention 2 already supports GQA\n        # if self.config.n_query_groups != 1:  # doing this would require a full kv cache with MQA (inefficient!)\n        #     # for MHA this is a no-op\n        #     k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)\n        #     v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)\n\n        q = q.reshape(B, T, -1, self.config.head_size)  # (B, T, nh_q, hs)\n        k = k.reshape(B, T, -1, self.config.head_size)\n        v = v.reshape(B, T, -1, self.config.head_size)\n\n        cos, sin = rope\n\n        # applying rope in fp32 significantly stabilizes training\n        # fused rope expects (batch_size, seqlen, nheads, headdim)\n        #q = apply_rotary_emb_func(q, cos, sin, False, True)\n        #k = apply_rotary_emb_func(k, cos, sin, False, True)\n        q, k = apply_rotary_pos_emb(q, k, cos, sin, input_pos)\n\n        # n_elem = int(self.config.rotary_percentage * self.config.head_size)\n\n        # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))\n        # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))\n        # print( (q_roped - q).sum())\n        # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)\n        # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)\n\n        if kv_cache is not None:\n            cache_k, cache_v = kv_cache\n            cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)\n            # check if reached token limit\n            if input_pos[-1] >= max_seq_length:\n                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)\n                # shift 1 position to the left\n                cache_k = torch.roll(cache_k, -1, dims=1)\n                cache_v = torch.roll(cache_v, -1, dims=1)\n\n            k = cache_k.index_copy_(1, input_pos, k)\n            v = cache_v.index_copy_(1, input_pos, v)\n            kv_cache = k, v\n\n        y = self.scaled_dot_product_attention(q, k, v, mask=mask)\n\n        y = y.reshape(B, T, C)  # re-assemble all head outputs side by side\n\n        # output projection\n        y = self.proj(y)\n\n        return y, kv_cache\n\n    def scaled_dot_product_attention(\n        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None\n    ):\n        scale = 1.0 / math.sqrt(self.config.head_size)\n        '''\n        if (\n            FlashAttention2Available\n            and mask is None\n            and q.device.type == \"cuda\"\n            and self.config.enable_flash_attn\n            #and q.dtype in (torch.float16, torch.bfloat16)\n        ):\n            from flash_attn import flash_attn_func\n            return flash_attn_func(q.to(self.config.flash_attn_dtype), k.to(self.config.flash_attn_dtype), v.to(self.config.flash_attn_dtype), dropout_p=0.0, softmax_scale=scale, causal=True).to(v.dtype)\n        '''\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n        if q.size() != k.size():\n            k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)\n            v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)\n        y = torch.nn.functional.scaled_dot_product_attention(\n            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None, scale=scale\n        )\n\n        return y.transpose(1, 2)\n\n    # Reference (non-fused) implementation equivalent to torch.nn.functional.scaled_dot_product_attention, kept for debugging:\n    def raw_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:\n        L, S = query.size(-2), key.size(-2)\n        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)\n        if is_causal:\n            assert attn_mask is None\n            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)\n            attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n            attn_bias = attn_bias.to(query.dtype)\n\n        if attn_mask is not None:\n            if attn_mask.dtype == torch.bool:\n                attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n            else:\n                attn_bias += attn_mask\n        attn_weight = query @ key.transpose(-2, -1) * scale_factor\n        attn_weight += attn_bias\n        attn_weight = torch.softmax(attn_weight, dim=-1)\n        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)\n        return attn_weight @ value\n\n\ndef test_attn(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:\n    # Reference implementation equivalent to torch.nn.functional.scaled_dot_product_attention:\n    L, S = query.size(-2), key.size(-2)\n    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)\n    if is_causal:\n        assert attn_mask is None\n        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)\n        attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n        attn_bias = attn_bias.to(query.dtype)\n\n    if attn_mask is not None:\n        if attn_mask.dtype == torch.bool:\n            attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n        else:\n            attn_bias += attn_mask\n    attn_weight = query @ key.transpose(-2, -1) * scale_factor\n    attn_weight += attn_bias\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)\n    return attn_weight @ value\n\nclass GptNeoxMLP(nn.Module):\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.fc(x)\n        x = torch.nn.functional.gelu(x)\n        return self.proj(x)\n\n\nclass LLaMAMLP(nn.Module):  ## NOTE: changed to use the torch activation version Dec 8.\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)\n        # self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_fc_1 = self.fc_1(x)\n        x_fc_2 = self.fc_2(x)\n        x = torch.nn.functional.silu(x_fc_1) * x_fc_2\n        return self.proj(x)\n        # return 
self.swiglu(x)\n\n\ndef build_rope_cache(\n    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1\n) -> RoPECache:\n    \"\"\"Enhanced Transformer with Rotary Position Embedding.\n\n    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/\n    transformers/rope/__init__.py. MIT License:\n    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.\n    \"\"\"\n    # $\\Theta = {\\theta_i = 10000^{\\frac{2(i-1)}{d}}, i \\in [1, 2, ..., \\frac{d}{2}]}$\n    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))\n\n    # Create position indexes `[0, 1, ..., seq_len - 1]`\n    seq_idx = torch.arange(seq_len, device=device) / condense_ratio\n\n    # Calculate the product of position index and $\\theta_i$\n    idx_theta = torch.outer(seq_idx, theta)\n    idx_theta = torch.cat((idx_theta, idx_theta), dim=-1)\n\n    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)\n\n    # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding\n    # if dtype == torch.bfloat16:\n    #     return cos.bfloat16(), sin.bfloat16()\n    # # this is to mimic the behaviour of complex32, else we will get different results\n    # if dtype in (torch.float16, torch.bfloat16, torch.int8):\n    #     return cos.half(), sin.half()\n    return cos, sin\n\n\ndef apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n    head_size = x.size(-1)\n    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)\n    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)\n    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)\n    roped = (x * cos) + (rotated * sin)\n    return roped.type_as(x)\n"
  },
  {
    "path": "lit_gpt/packed_dataset.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport os\nimport random\nimport struct\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import IterableDataset, get_worker_info\n\ndtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}\n\n\ndef code(dtype):\n    for k in dtypes:\n        if dtypes[k] == dtype:\n            return k\n    raise ValueError(dtype)\n\n\nHDR_MAGIC = b\"LITPKDS\"\nHDR_SIZE = 24  # bytes\n\n\nclass PackedDataset(IterableDataset):\n    def __init__(\n        self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0\n    ):\n        self._filenames = filenames\n        self._n_chunks = n_chunks\n        self._block_size = block_size\n        self._seed = seed\n        self._shuffle = shuffle\n        self._wrap = wrap\n        self._num_processes = num_processes\n        self._process_rank = process_rank\n\n    def __iter__(self):\n        worker_info = get_worker_info()\n        num_workers = worker_info.num_workers if worker_info is not None else 1\n        worker_id = worker_info.id if worker_info is not None else 0\n        num_shards = num_workers * self._num_processes\n        shard_id = self._process_rank * num_workers + worker_id\n\n        max_num_files = len(self._filenames) // num_shards * num_shards\n        filenames = self._filenames[shard_id:max_num_files:num_shards]\n\n        return PackedDatasetIterator(\n            filenames=filenames,\n            n_chunks=self._n_chunks,\n            block_size=self._block_size,\n            seed=self._seed,\n            shuffle=self._shuffle,\n            wrap=self._wrap,\n        )\n\n\nclass PackedDatasetBuilder(object):\n    def __init__(self, outdir, prefix, chunk_size, sep_token, dtype=\"auto\", vocab_size=None):\n        print(\"++++++++++{}\".format(sep_token))\n        if dtype == \"auto\":\n            if vocab_size is None:\n                raise ValueError(\"vocab_size cannot be None when dtype='auto'\")\n            if vocab_size is not None and vocab_size < 65500:\n                self._dtype = np.uint16\n            else:\n                self._dtype = np.int32\n        else:\n            self._dtype = dtype\n        self._counter = 0\n        self._chunk_size = chunk_size\n        self._outdir = outdir\n        self._prefix = prefix\n        self._sep_token = sep_token\n        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)\n        self._arr.fill(self._sep_token)\n        self._idx = 0\n        self._version = 1\n        self._filenames = []\n\n    def _write_chunk(self):\n        filename = f\"{self._prefix}_{self._counter:010d}.bin\"\n        filename = os.path.join(self._outdir, filename)\n\n        with open(filename, \"wb\") as f:\n            f.write(HDR_MAGIC)\n            f.write(struct.pack(\"<Q\", self._version))\n            f.write(struct.pack(\"<B\", 
code(self._dtype)))\n            f.write(struct.pack(\"<Q\", self._chunk_size))\n            f.write(self._arr.tobytes(order=\"C\"))\n\n        self._filenames.append(filename)\n        self._counter += 1\n        self._arr.fill(self._sep_token)\n        self._idx = 0\n\n    @property\n    def dtype(self):\n        return self._dtype\n\n    @property\n    def filenames(self):\n        return self._filenames.copy()\n\n    def add_array(self, arr):\n        while self._idx + arr.shape[0] > self._chunk_size:\n            part_len = self._chunk_size - self._idx\n            self._arr[self._idx : self._idx + part_len] = arr[:part_len]\n            self._write_chunk()\n            arr = arr[part_len:]\n\n        arr_len = arr.shape[0]\n        self._arr[self._idx : self._idx + arr_len] = arr\n        self._idx += arr_len\n\n    def write_reminder(self):\n        self._write_chunk()\n\n\nclass PackedDatasetIterator:\n    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):\n        self._seed = seed\n        self._shuffle = shuffle\n        self._rng = np.random.default_rng(seed) if shuffle else None\n        self._block_idxs = None\n\n        self._wrap = wrap\n\n        # TODO: instead of filenames, we could have a single text stream\n        #       (or text file) with the sequence of all files to be\n        #       fetched/loaded.\n        self._filenames = filenames\n        self._file_idx = 0\n\n        self._n_chunks = n_chunks\n\n        self._dtype = None\n        self._block_size = block_size\n        self._n_blocks = None\n\n        self._mmaps = []\n        self._buffers = []\n\n        self._block_idxs = []\n        self._curr_idx = 0\n\n        self._load_n_chunks()\n\n    def _read_header(self, path):\n        with open(path, \"rb\") as f:\n            magic = f.read(len(HDR_MAGIC))\n            assert magic == HDR_MAGIC, \"File doesn't match expected format.\"\n            version = struct.unpack(\"<Q\", f.read(8))\n            assert version == (1,)\n            (dtype_code,) = struct.unpack(\"<B\", f.read(1))\n            dtype = dtypes[dtype_code]\n            (chunk_size,) = struct.unpack(\"<Q\", f.read(8))\n        return dtype, chunk_size\n\n    def _close_mmaps(self):\n        for mmap in self._mmaps:\n            mmap._mmap.close()\n\n    def _load_n_chunks(self):\n        self._close_mmaps()\n        self._mmaps = []\n        self._buffers = []\n\n        if self._n_chunks > len(self._filenames[self._file_idx :]):\n            # if not self._wrap:\n            #     raise StopIteration\n            self._file_idx = 0\n        actual_n_chunks = min(self._n_chunks, len(self._filenames[self._file_idx :]))\n        for i in range(actual_n_chunks):\n            filename = self._filenames[self._file_idx + i]\n            if self._dtype is None:\n                self._dtype, self._chunk_size = self._read_header(filename)\n                self._n_blocks = self._chunk_size // self._block_size\n            # TODO: check header matches with previous files\n            mmap = np.memmap(filename, mode=\"r\", order=\"C\", offset=HDR_SIZE)\n            self._mmaps.append(mmap)\n            self._buffers.append(memoryview(mmap))\n\n        self._file_idx += actual_n_chunks\n        n_all_blocks = actual_n_chunks * self._n_blocks\n\n        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)\n\n        self._curr_idx = 0\n\n    def __del__(self):\n        self._close_mmaps()\n        del self._mmaps\n        del 
self._buffers\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        if self._curr_idx >= len(self._block_idxs):\n            self._load_n_chunks()\n            # TODO: trigger fetching next next n_chunks if remote\n        block_idx = self._block_idxs[self._curr_idx]\n        chunk_id = block_idx // self._n_blocks\n        buffer = self._buffers[chunk_id]\n        elem_id = (block_idx % self._n_blocks) * self._block_size\n        offset = np.dtype(self._dtype).itemsize * elem_id\n        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)\n        self._curr_idx += 1\n        return torch.from_numpy(arr.astype(np.int64))\n\n\nclass CombinedDataset(IterableDataset):\n    def __init__(self, datasets, seed, weights=None):\n        self._seed = seed\n        self._datasets = datasets\n        self._weights = weights\n        n_datasets = len(datasets)\n        if weights is None:\n            self._weights = [1 / n_datasets] * n_datasets\n\n    def __iter__(self):\n        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)\n\n\nclass CombinedDatasetIterator:\n    def __init__(self, datasets, seed, weights):\n        self._datasets = [iter(el) for el in datasets]\n        self._weights = weights\n        self._rng = random.Random(seed)\n\n    def __next__(self):\n        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)\n        return next(dataset)\n"
  },
  {
    "path": "lit_gpt/rmsnorm.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\n# Copyright (c) 2022, Tri Dao.\n# Adapted from https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16\n\nimport torch\nfrom torch.nn import init\n\n    \nclass RMSNorm(torch.nn.Module):\n    \"\"\"Root Mean Square Layer Normalization.\n\n    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:\n    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n    \"\"\"\n\n    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.ones(size))\n        self.eps = eps\n        self.dim = dim\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # NOTE: the original RMSNorm paper implementation is not equivalent\n        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)\n        x_normed = x * torch.rsqrt(norm_x + self.eps)\n        return self.weight * x_normed\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n\ntry:\n    import apex\n    class FusedRMSNorm(apex.normalization.FusedRMSNorm):\n        def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):\n            super().__init__(size, eps=eps, elementwise_affine=True)\n            self.eps = eps\n            self.weight = torch.nn.Parameter(torch.ones(size))\n            self.dim = dim\n            self.reset_parameters()\n\n        def reset_parameters(self):\n            init.ones_(self.weight)\n\n        # def forward(self, x):\n        #     return rms_norm(x, self.weight, self.eps)\nexcept:\n    print(\"Fail to import FusedRMSNorm, use RMSNorm instead.\")\n    FusedRMSNorm = RMSNorm\n"
  },
  {
    "path": "lit_gpt/rotary_ebm.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n \n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed.type_as(q), k_embed.type_as(k)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n \n"
  },
  {
    "path": "lit_gpt/speed_monitor.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport time\nfrom collections import deque\nfrom contextlib import nullcontext\nfrom typing import Any, Callable, Deque, Dict, Optional\n\nimport torch\nfrom lightning import Callback, Fabric, LightningModule, Trainer\nfrom lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only\nfrom lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only\n#from torch.utils.flop_counter import FlopCounterMode\nimport math\nfrom lit_gpt import GPT, Config\nfrom lit_gpt.utils import num_parameters\n\nGPU_AVAILABLE_FLOPS = {\n    # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet\n    # nvidia publishes spec sheet with a 2x sparsity factor\n    \"h100-sxm\": {\n        \"64-true\": 67e12,\n        \"32-true\": 67e12,\n        \"16-true\": 1.979e15 / 2,\n        \"16-mixed\": 1.979e15 / 2,\n        \"bf16-true\": 1.979e15 / 2,\n        \"bf16-mixed\": 1.979e15 / 2,\n        \"8-true\": 3.958e15 / 2,\n        \"8-mixed\": 3.958e15 / 2,\n    },\n    \"h100-pcie\": {\n        \"64-true\": 51e12,\n        \"32-true\": 51e12,\n        \"16-true\": 1.513e15 / 2,\n        \"16-mixed\": 1.513e15 / 2,\n        \"bf16-true\": 1.513e15 / 2,\n        \"bf16-mixed\": 1.513e15 / 2,\n        \"8-true\": 3.026e15 / 2,\n        \"8-mixed\": 3.026e15 / 2,\n    },\n    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf\n    # sxm and pcie have same flop counts\n    \"a100\": {\n        \"64-true\": 19.5e12,\n        \"32-true\": 19.5e12,\n        \"16-true\": 312e12,\n        \"16-mixed\": 312e12,\n        \"bf16-true\": 312e12,\n        \"bf16-mixed\": 312e12,\n    },\n    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf\n    \"a10g\": {\"32-true\": 31.2e12, \"16-true\": 125e12, \"16-mixed\": 125e12, \"bf16-true\": 125e12, \"bf16-mixed\": 125e12},\n    # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf\n    \"v100-sxm\": {\"64-true\": 7.8e12, \"32-true\": 15.7e12, \"16-true\": 125e12, \"16-mixed\": 125e12},\n    \"v100-pcie\": {\"64-true\": 7e12, \"32-true\": 14e12, \"16-true\": 112e12, \"16-mixed\": 112e12},\n    \"v100s-pcie\": {\"64-true\": 8.2e12, \"32-true\": 16.4e12, \"16-true\": 130e12, \"16-mixed\": 130e12},\n    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf\n    # sxm and pcie have same flop counts\n    \"t4\": {\"32-true\": 8.1e12, \"16-true\": 65e12, \"16-mixed\": 65e12, \"8-true\": 130e12, \"int4\": 260e12},\n    # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf\n    \"quadro rtx 5000\": {\"32-true\": 11.2e12, 
\"16-true\": 89.2e12, \"16-mixed\": 89.2e12},\n}\n\nTPU_AVAILABLE_FLOPS = {\n    # flop count for each TPU generation is the same for all precisions\n    # since bfloat16 precision is always used for performing matrix operations\n    # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16\n    # source: https://arxiv.org/pdf/1907.10701.pdf\n    \"v2\": 45e12,\n    # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3\n    \"v3\": 123e12,\n    # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4\n    \"v4\": 275e12,\n}\n\n\ndef get_flops_available(device: torch.device, precision: str) -> Optional[float]:\n    if device.type == \"cuda\":\n        device_name = torch.cuda.get_device_name(device).lower()\n        if \"h100\" in device_name and \"hbm3\" in device_name:\n            device_name = \"h100-sxm\"\n        elif \"h100\" in device_name and (\"pcie\" in device_name or \"hbm2e\" in device_name):\n            device_name = \"h100-pcie\"\n        elif \"a100\" in device_name:\n            device_name = \"a100\"\n        elif \"a10g\" in device_name:\n            device_name = \"a10g\"\n        elif \"v100-sxm\" in device_name:\n            device_name = \"v100-sxm\"\n        elif \"v100-pcie\" in device_name:\n            device_name = \"v100-pcie\"\n        elif \"t4\" in device_name:\n            device_name = \"t4\"\n        elif \"quadro rtx 5000\" in device_name:\n            device_name = \"quadro rtx 5000\"\n        else:\n            device_name = None\n\n        if device_name is not None:\n            try:\n                return int(GPU_AVAILABLE_FLOPS[device_name][precision])\n            except KeyError:\n                raise KeyError(\n                    f\"flop count not found for {device_name} with precision: {precision}; \"\n                    \"MFU cannot be calculated and reported.\"\n                )\n    elif device.type == \"xla\":\n        from torch_xla.experimental import tpu\n\n        device_name = tpu.get_tpu_env()[\"TYPE\"].lower()\n        try:\n            return int(TPU_AVAILABLE_FLOPS[device_name])\n        except KeyError:\n            raise KeyError(\n                f\"flop count not found for {device_name} with precision: {precision}; \"\n                \"MFU cannot be calculated and reported.\"\n            )\n\n    return None\n\n\n# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py\n\n\nclass SpeedMonitorBase:\n    \"\"\"Logs the training throughput and utilization.\n\n    +-------------------------------------+-----------------------------------------------------------+\n    | Key                                 | Logged data                                               |\n    +=====================================+===========================================================+\n    |                                     | Rolling average (over `window_size` most recent           |\n    | `throughput/batches_per_sec`        | batches) of the number of batches processed per second    |\n    |                                     |                                                           |\n    +-------------------------------------+-----------------------------------------------------------+\n    |                                     | Rolling average (over `window_size` most recent           |\n    | `throughput/samples_per_sec`        | batches) of the number of samples processed per second    
|\n    |                                     |                                                           |\n    +-------------------------------------+-----------------------------------------------------------+\n    |                                     | Rolling average (over `window_size` most recent           |\n    | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |\n    |                                     | This may include padding depending on dataset             |\n    +-------------------------------------+-----------------------------------------------------------+\n    |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |\n    | `throughput/flops_per_sec`          |                                                           |\n    |                                     |                                                           |\n    +-------------------------------------+-----------------------------------------------------------+\n    | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size        |\n    +-------------------------------------+-----------------------------------------------------------+\n    | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |\n    +-------------------------------------+-----------------------------------------------------------+\n    |                                     | `throughput/tokens_per_sec` divided by world size. This   |\n    | `throughput/device/tokens_per_sec`  | may include pad tokens depending on dataset               |\n    |                                     |                                                           |\n    +-------------------------------------+-----------------------------------------------------------+\n    |                                     | `throughput/flops_per_sec` divided by world size. Only    |\n    | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |\n    |                                     |                                                           |\n    +-------------------------------------+-----------------------------------------------------------+\n    |                                     | `throughput/device/flops_per_sec` divided by world size.  
|\n    | `throughput/device/mfu`             |                                                           |\n    |                                     |                                                           |\n    +-------------------------------------+-----------------------------------------------------------+\n    | `time/train`                        | Total elapsed training time                               |\n    +-------------------------------------+-----------------------------------------------------------+\n    | `time/val`                          | Total elapsed validation time                             |\n    +-------------------------------------+-----------------------------------------------------------+\n    | `time/total`                        | Total elapsed time (time/train + time/val)                |\n    +-------------------------------------+-----------------------------------------------------------+\n\n    Notes:\n        - The implementation assumes that devices are homogeneous as it normalizes by the world size.\n        - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or\n          batches/sec to measure throughput under this circumstance.\n        - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.\n          There is no widespread, realistic, and reliable implementation to compute them.\n          We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which\n          will almost always be an overestimate when compared to the true value.\n\n    Args:\n        window_size (int, optional): Number of batches to use for a rolling average of throughput.\n            Defaults to 100.\n        time_unit (str, optional): Time unit to use for `time` logging. Can be one of\n            'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.\n    \"\"\"\n\n    def __init__(\n        self,\n        flops_available: float,\n        log_dict: Callable[[Dict, int], None],\n        window_size: int = 100,\n        time_unit: str = \"hours\",\n        log_iter_interval: int = 1,\n    ):\n        self.flops_available = flops_available\n        self.log_dict = log_dict\n        self.log_iter_interval = log_iter_interval\n        # Track the batch num samples and wct to compute throughput over a window of batches\n        self.history_samples: Deque[int] = deque(maxlen=window_size + 1)\n        self.history_training_loss: Deque[int] = deque(maxlen=log_iter_interval)\n        self.history_wct: Deque[float] = deque(maxlen=window_size + 1)\n        self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)\n        self.history_flops: Deque[int] = deque(maxlen=window_size + 1)\n\n        self.divider = 1\n        if time_unit == \"seconds\":\n            self.divider = 1\n        elif time_unit == \"minutes\":\n            self.divider = 60\n        elif time_unit == \"hours\":\n            self.divider = 60 * 60\n        elif time_unit == \"days\":\n            self.divider = 60 * 60 * 24\n        else:\n            raise ValueError(\n                f'Invalid time_unit: {time_unit}. 
Must be one of \"seconds\", \"minutes\", \"hours\", or \"days\".'\n            )\n\n        # Keep track of time spent evaluating\n        self.total_eval_wct = 0.0\n        self.iter = -1\n\n    def on_train_batch_end(\n        self,\n        samples: int,  # total samples seen (per device)\n        train_elapsed: float,  # total training time (seconds)\n        world_size: int,\n        step_count: int,\n        flops_per_batch: Optional[int] = None,  # (per device)\n        lengths: Optional[int] = None,  # total length of the samples seen (per device)\n        train_loss: Optional[float] = None,\n    ):\n        self.iter += 1\n        metrics = {}\n\n        self.history_samples.append(samples)\n        self.history_training_loss.append(train_loss)\n        if lengths is not None:\n            self.history_lengths.append(lengths)\n            # if lengths are passed, there should be as many values as samples\n            assert len(self.history_samples) == len(self.history_lengths)\n        self.history_wct.append(train_elapsed)\n        if len(self.history_wct) == self.history_wct.maxlen:\n            elapsed_batches = len(self.history_samples) - 1\n            elapsed_samples = self.history_samples[-1] - self.history_samples[0]\n            elapsed_wct = self.history_wct[-1] - self.history_wct[0]\n            samples_per_sec = elapsed_samples * world_size / elapsed_wct\n            dev_samples_per_sec = elapsed_samples / elapsed_wct\n            metrics.update(\n                {\n                    \"throughput/batches_per_sec\": elapsed_batches * world_size / elapsed_wct,\n                    \"throughput/samples_per_sec\": samples_per_sec,\n                    \"throughput/device/batches_per_sec\": elapsed_batches / elapsed_wct,\n                    \"throughput/device/samples_per_sec\": dev_samples_per_sec,\n                }\n            )\n            if lengths is not None:\n                elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])\n                avg_length = elapsed_lengths / elapsed_batches  \n                metrics.update(\n                    {\n                        \"throughput/tokens_per_sec\": samples_per_sec * avg_length,\n                        \"throughput/device/tokens_per_sec\": dev_samples_per_sec * avg_length,\n                        \"total_tokens\": avg_length * world_size * samples,\n                    }\n                )\n                if train_loss is not None:\n                    avg_loss = sum(self.history_training_loss) / len(self.history_training_loss)\n                    metrics.update(\n                        {\n                            \"metric/train_loss\": avg_loss,\n                            \"metric/train_ppl\": math.exp(avg_loss)\n                        }\n                    )\n\n        if flops_per_batch is not None:\n            # sum of flops per batch across ranks\n            self.history_flops.append(flops_per_batch * world_size)\n        if len(self.history_flops) == self.history_flops.maxlen:\n            elapsed_flops = sum(self.history_flops) - self.history_flops[0]\n            elapsed_wct = self.history_wct[-1] - self.history_wct[0]\n            flops_per_sec = elapsed_flops / elapsed_wct\n            device_flops_per_sec = flops_per_sec / world_size\n            metrics.update(\n                {\"throughput/flops_per_sec\": flops_per_sec, \"throughput/device/flops_per_sec\": device_flops_per_sec}\n            )\n            if self.flops_available:\n                
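# MFU: achieved FLOP/s per device divided by the device's theoretical peak FLOP/s at the current precision.\n                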
metrics[\"throughput/device/mfu\"] = device_flops_per_sec / self.flops_available\n\n        metrics.update(\n            {\n                \"time/train\": train_elapsed / self.divider,\n                \"time/val\": self.total_eval_wct / self.divider,\n                \"time/total\": (train_elapsed + self.total_eval_wct) / self.divider,\n                \"samples\": samples,\n            }\n        )\n        if self.iter % self.log_iter_interval == 0:\n            self.log_dict(metrics, step_count)\n\n    def eval_end(self, eval_elapsed: float):\n        self.total_eval_wct += eval_elapsed  # seconds\n\n\nclass SpeedMonitorFabric(SpeedMonitorBase):\n    def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:\n        # TODO: this will not work properly if a precision plugin is passed to Fabric\n        flops_available = get_flops_available(fabric.device, fabric._connector._precision_input)\n        super().__init__(flops_available, fabric.log_dict, *args, **kwargs)\n\n    @fabric_rank_zero_only\n    def on_train_batch_end(self, *args: Any, **kwargs: Any):\n        super().on_train_batch_end(*args, **kwargs)\n\n\nclass SpeedMonitorCallback(Callback):\n    def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:\n        super().__init__()\n        self.speed_monitor: Optional[SpeedMonitorBase] = None\n        self.speed_monitor_kwargs = kwargs\n        self.length_fn = length_fn\n        self.batch_size = batch_size\n        self.eval_t0: int = 0\n        self.train_t0: int = 0\n        self.total_lengths: int = 0\n\n    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:\n        if self.speed_monitor is not None:\n            return  # already setup\n        # TODO: this will not work properly if a precision plugin is passed to Trainer\n        flops_available = get_flops_available(\n            trainer.strategy.root_device, trainer._accelerator_connector._precision_flag\n        )\n        self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)\n\n    @trainer_rank_zero_only\n    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        if trainer.fit_loop._should_accumulate():\n            return\n\n        self.train_t0 = time.perf_counter()\n\n    @trainer_rank_zero_only\n    def on_train_batch_end(\n        self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int\n    ) -> None:\n        self.total_lengths += self.length_fn(batch)\n        if trainer.fit_loop._should_accumulate():\n            return\n        train_elapsed = time.perf_counter() - self.train_t0\n        assert self.speed_monitor is not None\n        iter_num = trainer.fit_loop.total_batch_idx\n        assert (measured_flops := pl_module.measured_flops) is not None\n        self.speed_monitor.on_train_batch_end(\n            (iter_num + 1) * self.batch_size,\n            train_elapsed,\n            # this assumes that device FLOPs are the same and that all devices have the same batch size\n            trainer.world_size,\n            flops_per_batch=measured_flops,\n            lengths=self.total_lengths,\n        )\n\n    @trainer_rank_zero_only\n    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        self.eval_t0 = time.perf_counter()\n\n    @trainer_rank_zero_only\n    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        eval_elapsed 
= time.perf_counter() - self.eval_t0\n        assert self.speed_monitor is not None\n        self.speed_monitor.eval_end(eval_elapsed)\n\n\ndef flops_per_param(config: Config, n_params: int) -> int:\n    flops_per_token = 2 * n_params  # each parameter is used for a MAC (2 FLOPS) per network operation\n    # this assumes that all samples have a fixed length equal to the block size\n    # which is most likely false during finetuning\n    flops_per_seq = flops_per_token * config.block_size\n    attn_flops_per_seq = config.n_layer * 2 * 2 * (config.n_embd * (config.block_size**2))\n    return flops_per_seq + attn_flops_per_seq\n\n\ndef estimate_flops(model: GPT) -> int:\n    \"\"\"Measures estimated FLOPs for MFU.\n\n    Refs:\n        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1\n        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2\n    \"\"\"\n    # using all parameters for this is a naive overestimation because not all model parameters actually contribute to\n    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage\n    # (~10%) compared to the measured FLOPs, making those lower but more realistic.\n    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.\n    n_trainable_params = num_parameters(model, requires_grad=True)\n    trainable_flops = flops_per_param(model.config, n_trainable_params)\n    # forward + backward + gradients (assumes no gradient accumulation)\n    ops_per_step = 3 if model.training else 1\n    n_frozen_params = num_parameters(model, requires_grad=False)\n    frozen_flops = flops_per_param(model.config, n_frozen_params)\n    # forward + backward\n    frozen_ops_per_step = 2 if model.training else 1\n    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops\n\n\ndef measure_flops(model: GPT, x: torch.Tensor) -> int:\n    \"\"\"Measures real FLOPs for HFU.\"\"\"\n    # imported here because the module-level import is commented out above\n    from torch.utils.flop_counter import FlopCounterMode\n\n    flop_counter = FlopCounterMode(model, display=False)\n    ctx = nullcontext() if model.training else torch.no_grad()\n    with ctx, flop_counter:\n        y = model(x)\n        if model.training:\n            y.sum().backward()\n    return flop_counter.get_total_flops()\n"
  },
  {
    "path": "lit_gpt/tokenizer.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\n\n\nclass Tokenizer:\n    def __init__(self, checkpoint_dir: Path) -> None:\n        # some checkpoints have both files, `.model` takes precedence\n        if (vocabulary_path := checkpoint_dir / \"tokenizer.model\").is_file():\n            from sentencepiece import SentencePieceProcessor\n\n            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))\n            self.backend = \"sentencepiece\"\n            self.bos_id = self.processor.bos_id()\n            self.eos_id = self.processor.eos_id()\n        elif (vocabulary_path := checkpoint_dir / \"tokenizer.json\").is_file():\n            from tokenizers import Tokenizer as HFTokenizer\n\n            self.processor = HFTokenizer.from_file(str(vocabulary_path))\n            self.backend = \"huggingface\"\n            with open(checkpoint_dir / \"tokenizer_config.json\") as fp:\n                config = json.load(fp)\n            self.eos_id = self.token_to_id(config[\"eos_token\"])\n            bos_token = config.get(\"bos_token\")\n            self.bos_id = self.token_to_id(bos_token) if bos_token is not None else self.eos_id\n        else:\n            raise NotImplementedError\n\n    @property\n    def vocab_size(self) -> int:\n        if self.backend == \"huggingface\":\n            return self.processor.get_vocab_size(with_added_tokens=False)\n        if self.backend == \"sentencepiece\":\n            return self.processor.vocab_size()\n        raise RuntimeError\n\n    def token_to_id(self, token: str) -> int:\n        if self.backend == \"huggingface\":\n            id_ = self.processor.token_to_id(token)\n        elif self.backend == \"sentencepiece\":\n            id_ = self.processor.piece_to_id(token)\n        else:\n            raise RuntimeError\n        if id_ is None:\n            raise ValueError(f\"token {token!r} not found in the collection.\")\n        return id_\n\n    def encode(\n        self,\n        string: str,\n        device: Optional[torch.device] = None,\n        bos: bool = False,\n        eos: bool = True,\n        max_length: int = -1,\n    ) -> torch.Tensor:\n        if self.backend == \"huggingface\":\n            tokens = self.processor.encode(string).ids\n        elif self.backend == \"sentencepiece\":\n            tokens = self.processor.encode(string)\n        else:\n            raise RuntimeError\n        if bos:\n            bos_id = self.bos_id\n            if bos_id is None:\n                raise NotImplementedError(\"This tokenizer does not defined a bos token\")\n            tokens = [bos_id] + tokens\n        if eos:\n            tokens = tokens + [self.eos_id]\n        if max_length > 0:\n            tokens = tokens[:max_length]\n        return torch.tensor(tokens, dtype=torch.int, device=device)\n\n    def decode(self, tensor: 
torch.Tensor) -> str:\n        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()\n        return self.processor.decode(tokens)\n"
  },
  {
    "path": "lit_gpt/utils.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n\"\"\"Utility functions for training and inference.\"\"\"\n\nimport pickle\nimport sys\nimport warnings\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom io import BytesIO\nfrom pathlib import Path\nfrom types import MethodType\nfrom typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union\n\nimport torch\nimport torch.nn as nn\nfrom lightning.fabric.loggers import CSVLogger\nfrom torch.serialization import normalize_storage_type\n\n\ndef find_multiple(n: int, k: int) -> int:\n    assert k > 0\n    if n % k == 0:\n        return n\n    return n + k - (n % k)\n\n\ndef num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:\n    return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad)\n\n\n@contextmanager\ndef quantization(mode: Optional[str] = None):\n    if mode is None:\n        yield\n        return\n\n    if mode == \"bnb.int8\":\n        from quantize.bnb import InferenceLinear8bitLt\n\n        quantized_linear_cls = InferenceLinear8bitLt\n    elif mode == \"bnb.fp4\":\n        from quantize.bnb import Linear4bit\n\n        # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses\n        class QuantizedLinear(Linear4bit):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, quant_type=\"fp4\", compress_statistics=False, **kwargs)\n\n        quantized_linear_cls = QuantizedLinear\n    elif mode == \"bnb.fp4-dq\":\n        from quantize.bnb import Linear4bit\n\n        class QuantizedLinear(Linear4bit):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, quant_type=\"fp4\", compress_statistics=True, **kwargs)\n\n        quantized_linear_cls = QuantizedLinear\n    elif mode == \"bnb.nf4\":\n        from quantize.bnb import Linear4bit\n\n        class QuantizedLinear(Linear4bit):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, quant_type=\"nf4\", compress_statistics=False, **kwargs)\n\n        quantized_linear_cls = QuantizedLinear\n    elif mode == \"bnb.nf4-dq\":\n        from quantize.bnb import Linear4bit\n\n        class QuantizedLinear(Linear4bit):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, quant_type=\"nf4\", compress_statistics=True, **kwargs)\n\n        quantized_linear_cls = QuantizedLinear\n    elif mode == \"gptq.int4\":\n        from quantize.gptq import ColBlockQuantizedLinear\n\n        class QuantizedLinear(ColBlockQuantizedLinear):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, bits=4, tile_cols=-1, **kwargs)\n\n        quantized_linear_cls = QuantizedLinear\n    else:\n        raise ValueError(f\"Unknown quantization mode: {mode}\")\n\n    torch_linear_cls = 
torch.nn.Linear\n    torch.nn.Linear = quantized_linear_cls\n    yield\n    torch.nn.Linear = torch_linear_cls\n\n\n# this is taken from torchhacks https://github.com/lernapparat/torchhacks\n\n\nclass NotYetLoadedTensor:\n    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):\n        self.metatensor = metatensor\n        self.archiveinfo = archiveinfo\n        self.storageinfo = storageinfo\n        self.rebuild_args = rebuild_args\n\n    @classmethod\n    def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):\n        ret = func(*args)\n        if isinstance(ret, NotYetLoadedTensor):\n            old_lt = ret._load_tensor\n\n            def _load_tensor():\n                t = old_lt()\n                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)\n\n            ret._load_tensor = _load_tensor\n            return ret\n        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)\n\n    @classmethod\n    def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None):\n        if isinstance(data, NotYetLoadedTensor):\n            old_lt = data._load_tensor\n\n            def _load_tensor():\n                t = old_lt()\n                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)\n\n            data._load_tensor = _load_tensor\n            return data\n        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)\n\n    @classmethod\n    def rebuild_tensor_v2(\n        cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None\n    ):\n        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)\n        metatensor = torch._utils._rebuild_tensor_v2(\n            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata\n        )\n        storageinfo = storage.archiveinfo\n        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)\n\n    def _load_tensor(self):\n        name, storage_cls, fn, device, size = self.storageinfo\n        dtype = self.metatensor.dtype\n\n        uts = (\n            self.archiveinfo.zipfile_context.zf.get_storage_from_record(\n                f\"data/{fn}\", size * torch._utils._element_size(dtype), torch.UntypedStorage\n            )\n            ._typed_storage()\n            ._untyped_storage\n        )\n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True)\n        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        if kwargs is None:\n            kwargs = {}\n        loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args]\n        return func(*loaded_args, **kwargs)\n        # gc.collect would be costly here, maybe do it optionally\n\n    def __getattr__(self, name):\n        # properties\n        ## TODO: device, is_...??\n        ## TODO: mH, mT, H, T, data, imag, real\n        ## name ???\n        if name in {\n            \"dtype\",\n            \"grad\",\n            \"grad_fn\",\n            \"layout\",\n            \"names\",\n            \"ndim\",\n            \"output_nr\",\n            \"requires_grad\",\n            \"retains_grad\",\n            \"shape\",\n   
         \"volatile\",\n        }:\n            return getattr(self.metatensor, name)\n        if name in {\"size\"}:\n            return getattr(self.metatensor, name)\n        # materializing with contiguous is needed for quantization\n        if name in {\"contiguous\"}:\n            return getattr(self._load_tensor(), name)\n\n        raise AttributeError(f\"{type(self)} does not have {name}\")\n\n    def __repr__(self):\n        return f\"NotYetLoadedTensor({repr(self.metatensor)})\"\n\n\nclass LazyLoadingUnpickler(pickle.Unpickler):\n    def __init__(self, file, zipfile_context):\n        super().__init__(file)\n        self.zipfile_context = zipfile_context\n\n    def find_class(self, module, name):\n        res = super().find_class(module, name)\n        if module == \"torch._utils\" and name == \"_rebuild_tensor_v2\":\n            return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)\n        if module == \"torch._tensor\" and name == \"_rebuild_from_type_v2\":\n            return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)\n        if module == \"torch._utils\" and name == \"_rebuild_parameter\":\n            return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)\n        return res\n\n    def persistent_load(self, pid):\n        name, cls, fn, device, size = pid\n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            s = torch.storage.TypedStorage(dtype=cls().dtype, device=\"meta\")\n        s.archiveinfo = pid\n        return s\n\n\nclass lazy_load:\n    def __init__(self, fn):\n        self.zf = torch._C.PyTorchFileReader(str(fn))\n        with BytesIO(self.zf.get_record(\"data.pkl\")) as pkl:\n            mup = LazyLoadingUnpickler(pkl, self)\n            self.sd = mup.load()\n\n    def __enter__(self):\n        return self.sd\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        del self.zf  # I don't think there is a way to force closing...\n        self.zf = None\n\n\ndef check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:\n    files = {\n        \"lit_model.pth\": (checkpoint_dir / \"lit_model.pth\").is_file(),\n        \"lit_config.json\": (checkpoint_dir / \"lit_config.json\").is_file(),\n        \"tokenizer.json OR tokenizer.model\": (checkpoint_dir / \"tokenizer.json\").is_file() or (\n            checkpoint_dir / \"tokenizer.model\"\n        ).is_file(),\n        \"tokenizer_config.json\": (checkpoint_dir / \"tokenizer_config.json\").is_file(),\n    }\n    if checkpoint_dir.is_dir():\n        if all(files.values()):\n            # we're good\n            return\n        problem = f\" is missing the files: {[f for f, exists in files.items() if not exists]!r}\"\n    else:\n        problem = \" is not a checkpoint directory\"\n\n    # list locally available checkpoints\n    available = list(Path(\"checkpoints\").glob(\"*/*\"))\n    if available:\n        options = \"\\n --checkpoint_dir \".join([\"\"] + [repr(str(p.resolve())) for p in available])\n        extra = f\"\\nYou have downloaded locally:{options}\\n\"\n    else:\n        extra = \"\"\n\n    error_message = (\n        f\"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}.\"\n        \"\\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\\n\"\n        f\"{extra}\\nSee all download options by running:\\n python scripts/download.py\"\n    )\n    print(error_message, file=sys.stderr)\n    raise SystemExit(1)\n\n\nclass SavingProxyForStorage:\n    def 
__init__(self, obj, saver, protocol_version=5):\n        self.protocol_version = protocol_version\n        self.saver = saver\n        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):\n            raise TypeError(f\"expected storage, not {type(obj)}\")\n\n        # this logic is taken from PyTorch 2.0+ torch/serialization.py\n        if isinstance(obj, torch.storage.TypedStorage):\n            # PT upstream wants to deprecate this eventually...\n            storage = obj._untyped_storage\n            storage_type_str = obj._pickle_storage_type()\n            storage_type = getattr(torch, storage_type_str)\n            storage_numel = obj._size()\n        else:\n            storage = obj\n            storage_type = normalize_storage_type(type(obj))\n            storage_numel = storage.nbytes()\n\n        storage_key = saver._write_storage_and_return_key(storage)\n        location = torch.serialization.location_tag(storage)\n\n        self.storage_info = (\"storage\", storage_type, storage_key, location, storage_numel)\n\n    def __reduce_ex__(self, protocol_version):\n        assert False, \"this should be handled out of band\"\n\n\nclass SavingProxyForTensor:\n    def __init__(self, tensor, saver, protocol_version=5):\n        self.protocol_version = protocol_version\n        self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version)\n        assert isinstance(storage, torch.storage.TypedStorage), \"Please check for updates\"\n        storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)\n        self.reduce_args = (storage_proxy, *other_reduce_args)\n\n    def __reduce_ex__(self, protocol_version):\n        if protocol_version != self.protocol_version:\n            raise RuntimeError(f\"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}\")\n        return self.reduce_ret_fn, self.reduce_args\n\n\nclass IncrementalPyTorchPickler(pickle.Pickler):\n    def __init__(self, saver, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.storage_dtypes = {}\n        self.saver = saver\n        self.id_map = {}\n\n    # this logic is taken from PyTorch 2.0+ torch/serialization.py\n    def persistent_id(self, obj):\n        # FIXME: the docs say that persistent_id should only return a string\n        # but torch store returns tuples. 
This works only in the binary protocol\n        # see\n        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects\n        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537\n        if isinstance(obj, SavingProxyForStorage):\n            return obj.storage_info\n\n        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):\n            if isinstance(obj, torch.storage.TypedStorage):\n                # TODO: Once we decide to break serialization FC, this case\n                # can be deleted\n                storage = obj._untyped_storage\n                storage_dtype = obj.dtype\n                storage_type_str = obj._pickle_storage_type()\n                storage_type = getattr(torch, storage_type_str)\n                storage_numel = obj._size()\n\n            else:\n                storage = obj\n                storage_dtype = torch.uint8\n                storage_type = normalize_storage_type(type(obj))\n                storage_numel = storage.nbytes()\n\n            # If storage is allocated, ensure that any other saved storages\n            # pointing to the same data all have the same dtype. If storage is\n            # not allocated, don't perform this check\n            if storage.data_ptr() != 0:\n                if storage.data_ptr() in self.storage_dtypes:\n                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:\n                        raise RuntimeError(\n                            \"Cannot save multiple tensors or storages that view the same data as different types\"\n                        )\n                else:\n                    self.storage_dtypes[storage.data_ptr()] = storage_dtype\n\n            storage_key = self.id_map.get(storage._cdata)\n            if storage_key is None:\n                storage_key = self.saver._write_storage_and_return_key(storage)\n                self.id_map[storage._cdata] = storage_key\n            location = torch.serialization.location_tag(storage)\n\n            return (\"storage\", storage_type, storage_key, location, storage_numel)\n\n        return None\n\n\nclass incremental_save:\n    def __init__(self, name):\n        self.name = name\n        self.zipfile = torch._C.PyTorchFileWriter(str(name))\n        self.has_saved = False\n        self.next_key = 0\n\n    def __enter__(self):\n        return self\n\n    def store_early(self, tensor):\n        if isinstance(tensor, torch.Tensor):\n            return SavingProxyForTensor(tensor, self)\n        raise TypeError(f\"can only store tensors early, not {type(tensor)}\")\n\n    def save(self, obj):\n        if self.has_saved:\n            raise RuntimeError(\"have already saved\")\n        # Write the pickle data for `obj`\n        data_buf = BytesIO()\n        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)\n        pickler.dump(obj)\n        data_value = data_buf.getvalue()\n        self.zipfile.write_record(\"data.pkl\", data_value, len(data_value))\n        self.has_saved = True\n\n    def _write_storage_and_return_key(self, storage):\n        if self.has_saved:\n            raise RuntimeError(\"have already saved\")\n        key = self.next_key\n        self.next_key += 1\n        name = f\"data/{key}\"\n        if storage.device.type != \"cpu\":\n            storage = storage.cpu()\n        num_bytes = storage.nbytes()\n        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)\n        return key\n\n    def __exit__(self, type, value, 
traceback):\n        self.zipfile.write_end_of_file()\n\n\nT = TypeVar(\"T\")\n\n\ndef step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T:\n    logger = cls(*args, **kwargs)\n\n    def merge_by(dicts, key):\n        from collections import defaultdict\n\n        out = defaultdict(dict)\n        for d in dicts:\n            if key in d:\n                out[d[key]].update(d)\n        return [v for _, v in sorted(out.items())]\n\n    def save(self) -> None:\n        \"\"\"Overridden to merge CSV by the step number.\"\"\"\n        import csv\n\n        if not self.metrics:\n            return\n        metrics = merge_by(self.metrics, \"step\")\n        keys = sorted({k for m in metrics for k in m})\n        with self._fs.open(self.metrics_file_path, \"w\", newline=\"\") as f:\n            writer = csv.DictWriter(f, fieldnames=keys)\n            writer.writeheader()\n            writer.writerows(metrics)\n\n    logger.experiment.save = MethodType(save, logger.experiment)\n\n    return logger\n\n\ndef chunked_cross_entropy(\n    logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128\n) -> torch.Tensor:\n    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate\n    # the memory usage in fine-tuning settings with low number of parameters.\n    # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing\n    # the memory spike's magnitude\n\n    # lm_head was chunked (we are fine-tuning)\n    if isinstance(logits, list):\n        # don't want to chunk cross entropy\n        if chunk_size == 0:\n            logits = torch.cat(logits, dim=1)\n            logits = logits.reshape(-1, logits.size(-1))\n            targets = targets.reshape(-1)\n            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)\n\n        # chunk cross entropy\n        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]\n        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]\n        loss_chunks = [\n            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction=\"none\")\n            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)\n        ]\n        return torch.cat(loss_chunks).mean()\n\n    # no chunking at all\n    logits = logits.reshape(-1, logits.size(-1))\n    targets = targets.reshape(-1)\n    if chunk_size == 0:\n        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)\n\n    # lm_head wasn't chunked, chunk cross entropy\n    logit_chunks = logits.split(chunk_size)\n    target_chunks = targets.split(chunk_size)\n    loss_chunks = [\n        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction=\"none\")\n        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)\n    ]\n    return torch.cat(loss_chunks).mean()\n\n\ndef map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:\n    for checkpoint_name, attribute_name in mapping.items():\n        full_checkpoint_name = prefix + checkpoint_name\n        if full_checkpoint_name in state_dict:\n            full_attribute_name = prefix + attribute_name\n            state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)\n    return state_dict\n\n\ndef get_default_supported_precision(training: bool, tpu: bool = 
False) -> str:\n    \"\"\"Return default precision that is supported by the hardware.\n\n    Args:\n        training: `-mixed` or `-true` version of the precision to use\n        tpu: whether TPU device is used\n\n    Returns:\n        default precision that is suitable for the task and is supported by the hardware\n    \"\"\"\n    if tpu:\n        return \"32-true\"\n    if not torch.cuda.is_available() or torch.cuda.is_bf16_supported():\n        return \"bf16-mixed\" if training else \"bf16-true\"\n    return \"16-mixed\" if training else \"16-true\"\n"
  },
  {
    "path": "pretrain/tinyllama.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport glob\nimport math\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Union\nimport math\nimport lightning as L\nimport torch\nfrom lightning.fabric.strategies import FSDPStrategy, XLAStrategy\nfrom torch.utils.data import DataLoader\nfrom functools import partial\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually\nfrom lit_gpt.model import GPT, Block, Config, CausalSelfAttention\nfrom lit_gpt.packed_dataset import CombinedDataset, PackedDataset\nfrom lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor\nfrom lit_gpt.speed_monitor import estimate_flops, measure_flops\nfrom lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load\nfrom pytorch_lightning.loggers import WandbLogger\n#from lit_gpt import FusedCrossEntropyLoss\nimport random\n\nmodel_name = 'tiny_LLaMA_135M_2k' # model to train\n\nname = \"tinyllama\"\nout_dir = Path(\"./out\") / (name+\"_135M_2k\")\n\ndefault_seed=3407\n\n# Hyperparameters\nnum_node=8\nnum_of_devices = 8\nglobal_batch_size = 1024/num_node\nlearning_rate = 6e-4\nmicro_batch_size = 16\nnum_epochs=1\nnum_total_token_in_b = 670 * num_epochs\n\nwarmup_steps = 2000\nlog_step_interval = 50\neval_iters = 1000\nsave_step_interval = 2000\neval_step_interval = 2000\n\n\nweight_decay = 1e-1\nbeta1 = 0.9\nbeta2 = 0.95\ngrad_clip = 1.0\ndecay_lr = True\nmin_lr = 6e-5\n\nbatch_size = global_batch_size // num_of_devices\ngradient_accumulation_steps = math.ceil(batch_size / micro_batch_size)\nactual_global_batch = gradient_accumulation_steps*micro_batch_size*num_of_devices*num_node\nprint(actual_global_batch)\nmax_step = int(num_total_token_in_b * 10**9/(actual_global_batch*2048)//save_step_interval + 1)*save_step_interval\nassert gradient_accumulation_steps > 0\nwarmup_iters = warmup_steps * gradient_accumulation_steps\n\n\n\nimport math\nmax_iters = max_step * gradient_accumulation_steps\nlr_decay_iters = max_iters\nlog_iter_interval = math.ceil(log_step_interval * gradient_accumulation_steps)\n\n\n\n# Treat all dataset equally by their size. 
If you want to use a different weight for a dataset, add it to the list with the weight.\ntrain_data_config = [\n    (\"train_\", 1.0)\n]\n\nval_data_config = [\n    (\"validation_slim\", 1.0),\n]\n\nhparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith(\"_\")}\nlogger = step_csv_logger(\"out\", name, flush_logs_every_n_steps=log_iter_interval)\nwandb_logger = WandbLogger()\n\n\ndef setup(\n    devices: int = 8,\n    train_data_dir: Path = Path(\"./slim_star_combined\"),\n    val_data_dir: Optional[Path] = None,\n    precision: Optional[str] = 'bf16-mixed',\n    tpu: bool = False,\n    resume: Union[bool, Path] = False,\n    model_name: Optional[str] = None,\n) -> None:\n    precision = precision or get_default_supported_precision(training=True, tpu=tpu)\n\n    if devices > 1:\n        if tpu:\n            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.\n            devices = \"auto\"\n            strategy = XLAStrategy(sync_module_states=False)\n        else:\n            strategy = FSDPStrategy(\n                auto_wrap_policy={Block},\n                activation_checkpointing_policy=None,\n                state_dict_type=\"full\",\n                limit_all_gathers=True,\n                cpu_offload=False,\n            )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])\n    fabric.print(hparams)\n    # fabric.launch(main, train_data_dir, val_data_dir, resume)\n    main(fabric, train_data_dir, val_data_dir, resume, model_name)\n\n\ndef main(fabric, train_data_dir, val_data_dir, resume, model_name=None):\n    monitor = Monitor(fabric, window_size=2, time_unit=\"seconds\", log_iter_interval=log_iter_interval)\n\n    if fabric.global_rank == 0:\n        out_dir.mkdir(parents=True, exist_ok=True)\n\n    config = Config.from_name(model_name)\n\n    train_dataloader, val_dataloader = create_dataloaders(\n        batch_size=micro_batch_size,\n        block_size=config.block_size,\n        fabric=fabric,\n        train_data_dir=train_data_dir,\n        val_data_dir=val_data_dir,\n        seed=default_seed,\n    )\n    if val_dataloader is None:\n        train_dataloader = fabric.setup_dataloaders(train_dataloader)\n    else:\n        train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n\n    fabric.seed_everything(default_seed)  # same seed for every process to init model (FSDP)\n\n    fabric.print(f\"Loading model with {config.__dict__}\")\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=False):\n        model = GPT(config)\n        model.apply(partial(model._init_weights, n_layer=config.n_layer))\n\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\")\n    fabric.print(f\"Total parameters {num_parameters(model):,}\")\n\n    model = fabric.setup(model)\n    optimizer = torch.optim.AdamW(\n        model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False\n    )\n    # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), adam_w_mode=True)\n    optimizer = fabric.setup_optimizers(optimizer)\n\n    state = {\"model\": model, \"optimizer\": optimizer, \"hparams\": hparams, \"iter_num\": 0, \"step_count\": 0}\n\n    if resume is True:\n        resume = sorted(out_dir.glob(\"*.pth\"))[-1]\n    if resume:\n        fabric.print(f\"Resuming training from {resume}\")\n        fabric.load(resume, state)\n\n    train_time = time.perf_counter()\n    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)\n    fabric.print(f\"Training time: {(time.perf_counter()-train_time):.2f}s\")\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\")\n\n\ndef train(fabric, state, train_dataloader, val_dataloader, monitor, resume):\n    model = state[\"model\"]\n    optimizer = state[\"optimizer\"]\n\n    # if val_dataloader is not None:\n    #     validate(fabric, model, val_dataloader)  # sanity check\n    model.train()\n\n    meta_model = GPT(model.config).cuda()\n    # \"estimated\" is not as precise as \"measured\". Estimated is optimistic but widely used in the wild.\n    # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,\n    # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead\n    estimated_flops = estimate_flops(meta_model) * micro_batch_size\n    fabric.print(f\"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}\")\n    x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))\n    # measure_flops would run on the meta device and triggers a FusedRMSNorm error, so it is disabled\n    # measured_flops = measure_flops(meta_model, x)\n    # fabric.print(f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\")\n    del meta_model, x\n\n    total_lengths = 0\n    total_t0 = time.perf_counter()\n\n    if fabric.device.type == \"xla\":\n        import torch_xla.core.xla_model as xm\n\n        xm.mark_step()\n\n    initial_iter = state[\"iter_num\"]\n    curr_iter = 0\n\n    loss_func = torch.nn.CrossEntropyLoss()  # FusedCrossEntropyLoss()\n    for i, train_data in enumerate(train_dataloader):\n        # resume loader state. This is not elegant but it works. 
Should rewrite it in the future.\n        if resume:\n            if curr_iter < initial_iter:\n                curr_iter += 1\n                continue\n            else:\n                resume = False\n                curr_iter = -1\n                fabric.barrier()\n                fabric.print(\"resume finished, took {:.2f} seconds\".format(time.perf_counter() - total_t0))\n        if state[\"iter_num\"] >= max_iters:\n            break\n\n        # determine and set the learning rate for this iteration\n        lr = get_lr(state[\"iter_num\"]) if decay_lr else learning_rate\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr\n\n        iter_t0 = time.perf_counter()\n\n        input_ids = train_data[:, 0 : model.config.block_size].contiguous()\n        targets = train_data[:, 1 : model.config.block_size + 1].contiguous()\n        is_accumulating = (state[\"iter_num\"] + 1) % gradient_accumulation_steps != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids)\n            loss = loss_func(logits.transpose(1, 2), targets)\n            # loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n            fabric.backward(loss / gradient_accumulation_steps)\n\n        if not is_accumulating:\n            fabric.clip_gradients(model, optimizer, max_norm=grad_clip)\n            optimizer.step()\n            optimizer.zero_grad()\n            state[\"step_count\"] += 1\n        elif fabric.device.type == \"xla\":\n            xm.mark_step()\n        state[\"iter_num\"] += 1\n        # input_ids: B x L\n        total_lengths += input_ids.size(1)\n        t1 = time.perf_counter()\n        if i % log_step_interval == 0:\n            fabric.print(\n                f\"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:\"\n                f\" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}\"\n                f\" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. \"\n                # print days as well\n                f\" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
\"\n            )\n \n        monitor.on_train_batch_end(\n            state[\"iter_num\"] * micro_batch_size,\n            t1 - total_t0,\n            # this assumes that device FLOPs are the same and that all devices have the same batch size\n            fabric.world_size,\n            state[\"step_count\"],\n            flops_per_batch=estimated_flops,\n            lengths=total_lengths,\n            train_loss = loss.item()\n        )\n\n            \n            \n            \n        if val_dataloader is not None and not is_accumulating and state[\"step_count\"] % eval_step_interval == 0:\n            \n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader)\n            t1 = time.perf_counter() - t0\n            monitor.eval_end(t1)\n            fabric.print(f\"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms\")\n            fabric.log_dict({\"metric/val_loss\": val_loss.item(), \"total_tokens\": model.config.block_size * (state[\"iter_num\"] + 1) * micro_batch_size * fabric.world_size}, state[\"step_count\"])\n            fabric.log_dict({\"metric/val_ppl\": math.exp(val_loss.item()), \"total_tokens\": model.config.block_size * (state[\"iter_num\"] + 1) * micro_batch_size * fabric.world_size}, state[\"step_count\"])\n            fabric.barrier()\n        if not is_accumulating and state[\"step_count\"] % save_step_interval == 0:\n            checkpoint_path = out_dir / f\"iter-{state['iter_num']:06d}-ckpt.pth\"\n            fabric.print(f\"Saving checkpoint to {str(checkpoint_path)!r}\")\n            # only works for pytorch>=2.0\n            fabric.save(checkpoint_path, state)\n\n        \n@torch.no_grad()\ndef validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:\n    fabric.print(\"Validating ...\")\n    model.eval()\n\n    losses = torch.zeros(eval_iters, device=fabric.device)\n    for k, val_data in enumerate(val_dataloader):\n        if k >= eval_iters:\n            break\n        input_ids = val_data[:, 0 : model.config.block_size].contiguous()\n        targets = val_data[:, 1 : model.config.block_size + 1].contiguous()\n        logits = model(input_ids)\n        loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n\n        # loss_func = FusedCrossEntropyLoss()\n        # loss = loss_func(logits, targets)\n        losses[k] = loss.item()\n        \n    out = losses.mean()\n\n    model.train()\n    return out\n\n\ndef create_dataloader(\n    batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split=\"train\"\n) -> DataLoader:\n    datasets = []\n    data_config = train_data_config if split == \"train\" else val_data_config\n    for prefix, _ in data_config:\n        filenames = sorted(glob.glob(str(data_dir / f\"{prefix}*\")))\n        random.seed(seed)\n        random.shuffle(filenames)\n\n        dataset = PackedDataset(\n            filenames,\n            # n_chunks control the buffer size. \n            # Note that the buffer size also impacts the random shuffle\n            # (PackedDataset is an IterableDataset. 
So the shuffle is done by prefetch a buffer and shuffle the buffer)\n            n_chunks=128,\n            block_size=block_size,\n            shuffle=shuffle,\n            seed=seed+fabric.global_rank,\n            num_processes=fabric.world_size,\n            process_rank=fabric.global_rank,\n        )\n        datasets.append(dataset)\n\n    if not datasets:\n        raise RuntimeError(\n            f\"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset.\"\n        )\n\n    weights = [weight for _, weight in data_config]\n    sum_weights = sum(weights)\n    weights = [el / sum_weights for el in weights]\n\n    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)\n\n    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n\n\ndef create_dataloaders(\n    batch_size: int,\n    block_size: int,\n    fabric,\n    train_data_dir: Path = Path(\"data/redpajama_sample\"),\n    val_data_dir: Optional[Path] = None,\n    seed: int = 12345,\n) -> Tuple[DataLoader, DataLoader]:\n    # Increase by one because we need the next word as well\n    effective_block_size = block_size + 1\n    train_dataloader = create_dataloader(\n        batch_size=batch_size,\n        block_size=effective_block_size,\n        fabric=fabric,\n        data_dir=train_data_dir,\n        shuffle=True,\n        seed=seed,\n        split=\"train\"\n    )\n    val_dataloader = (\n        create_dataloader(\n            batch_size=batch_size,\n            block_size=effective_block_size,\n            fabric=fabric,\n            data_dir=val_data_dir,\n            shuffle=False,\n            seed=seed,\n            split=\"validation\"\n        )\n        if val_data_dir\n        else None\n    )\n    return train_dataloader, val_dataloader\n\n\n# learning rate decay scheduler (cosine with warmup)\ndef get_lr(it):\n    # 1) linear warmup for warmup_iters steps\n    if it < warmup_iters:\n        return learning_rate * it / warmup_iters\n    # 2) if it > lr_decay_iters, return min learning rate\n    if it > lr_decay_iters:\n        return min_lr\n    # 3) in between, use cosine decay down to min learning rate\n    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n    assert 0 <= decay_ratio <= 1\n    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1\n    return min_lr + coeff * (learning_rate - min_lr)\n\n\nif __name__ == \"__main__\":\n    # Uncomment this line if you see an error: \"Expected is_sm80 to be true, but got false\"\n    # torch.backends.cuda.enable_flash_sdp(False)\n    torch.set_float32_matmul_precision(\"high\")\n\n    from jsonargparse import CLI\n\n    CLI(setup)\n"
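\n\n# Worked example (defaults above): warmup_iters = 2000 * gradient_accumulation_steps = 2000,\n# so get_lr(1000) = 3e-4, the peak 6e-4 is reached at it = 2000, and the cosine then\n# decays the rate to min_lr = 6e-5 at lr_decay_iters = max_iters.\n# learning rate decay scheduler (cosine with warmup)\ndef get_lr(it):\n    # 1) linear warmup for warmup_iters steps\n    if it < warmup_iters:\n        return learning_rate * it / warmup_iters\n    # 2) if it > lr_decay_iters, return min learning rate\n    if it > lr_decay_iters:\n        return min_lr\n    # 3) in between, use cosine decay down to min learning rate\n    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n    assert 0 <= decay_ratio <= 1\n    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1\n    return min_lr + coeff * (learning_rate - min_lr)\n\n\nif __name__ == \"__main__\":\n    # Uncomment this line if you see an error: \"Expected is_sm80 to be true, but got false\"\n    # torch.backends.cuda.enable_flash_sdp(False)\n    torch.set_float32_matmul_precision(\"high\")\n\n    from jsonargparse import CLI\n\n    CLI(setup)\n"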
  },
  {
    "path": "pretrain/tinyllama_code.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport glob\nimport math\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Union\nimport math\nimport lightning as L\nimport torch\nfrom lightning.fabric.strategies import FSDPStrategy, XLAStrategy\nfrom torch.utils.data import DataLoader\nfrom functools import partial\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually\nfrom lit_gpt.model import GPT, Block, Config, CausalSelfAttention\nfrom lit_gpt.packed_dataset import CombinedDataset, PackedDataset\nfrom lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor\nfrom lit_gpt.speed_monitor import estimate_flops, measure_flops\nfrom lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load\nfrom pytorch_lightning.loggers import WandbLogger\n#from lit_gpt import FusedCrossEntropyLoss\nimport random\n\nmodel_name = 'tiny_LLaMA_135M_2k' # model to train\n\nname = \"tinyllama\"\nout_dir = Path(\"./out\") / (name+\"_135M_2k_code\")\n\ndefault_seed=3407\n\n# Hyperparameters\nnum_node=1\nnum_of_devices = 8\nglobal_batch_size = 320/num_node\nlearning_rate = 3e-4\nmicro_batch_size = 16\nnum_epochs=1\nnum_total_token_in_b = 21 * num_epochs\n\nwarmup_steps = 2000\nlog_step_interval = 10\neval_iters = 1000\nsave_step_interval = 2000\neval_step_interval = 2000\n\n\nweight_decay = 1e-1\nbeta1 = 0.9\nbeta2 = 0.95\ngrad_clip = 1.0\ndecay_lr = True\nmin_lr = 3e-5\n\nbatch_size = global_batch_size // num_of_devices\ngradient_accumulation_steps = batch_size // micro_batch_size\nactual_global_batch = gradient_accumulation_steps*micro_batch_size*num_of_devices*num_node\nprint(actual_global_batch)\nmax_step = int(num_total_token_in_b * 10**9/(actual_global_batch*2048)//save_step_interval + 1)*save_step_interval\nassert gradient_accumulation_steps > 0\nwarmup_iters = warmup_steps * gradient_accumulation_steps\n\n\n\nimport math\nmax_iters = max_step * gradient_accumulation_steps\nlr_decay_iters = max_iters\nlog_iter_interval = math.ceil(log_step_interval * gradient_accumulation_steps)\n\n\n\n# Treat all dataset equally by their size. 
If you want to use a different weight for a dataset, add it to the list with the weight.\ntrain_data_config = [\n    (\"train_\", 1.0),\n]\n\nval_data_config = [\n    (\"validation_slim\", 1.0),\n]\n\nhparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith(\"_\")}\nlogger = step_csv_logger(\"out\", name, flush_logs_every_n_steps=log_iter_interval)\nwandb_logger = WandbLogger()\n\n\ndef setup(\n    devices: int = 8,\n    train_data_dir: Path = Path(\"./slim_star_combined\"),\n    val_data_dir: Optional[Path] = None,\n    precision: Optional[str] = 'bf16-mixed',\n    tpu: bool = False,\n    resume: Union[bool, Path] = False,\n    model_name: Optional[str] = None,\n    checkpoint_path: Optional[str] = None,\n) -> None:\n    precision = precision  # or get_default_supported_precision(training=True, tpu=tpu)\n\n    if devices > 1:\n        if tpu:\n            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.\n            devices = \"auto\"\n            strategy = XLAStrategy(sync_module_states=False)\n        else:\n            strategy = FSDPStrategy(\n                auto_wrap_policy={Block},\n                activation_checkpointing_policy=None,\n                state_dict_type=\"full\",\n                limit_all_gathers=True,\n                cpu_offload=False,\n            )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])\n    fabric.print(hparams)\n    # fabric.launch(main, train_data_dir, val_data_dir, resume)\n    main(fabric, train_data_dir, val_data_dir, resume, model_name, checkpoint_path)\n\n\ndef main(fabric, train_data_dir, val_data_dir, resume, model_name=None, checkpoint_path=None):\n    print('continuing from {}'.format(checkpoint_path))\n    monitor = Monitor(fabric, window_size=2, time_unit=\"seconds\", log_iter_interval=log_iter_interval)\n\n    if fabric.global_rank == 0:\n        out_dir.mkdir(parents=True, exist_ok=True)\n\n    config = Config.from_name(model_name)\n\n    train_dataloader, val_dataloader = create_dataloaders(\n        batch_size=micro_batch_size,\n        block_size=config.block_size,\n        fabric=fabric,\n        train_data_dir=train_data_dir,\n        val_data_dir=val_data_dir,\n        seed=default_seed,\n    )\n    if val_dataloader is None:\n        train_dataloader = fabric.setup_dataloaders(train_dataloader)\n    else:\n        train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n\n    fabric.seed_everything(default_seed)  # same seed for every process to init model (FSDP)\n\n    fabric.print(f\"Loading model with {config.__dict__}\")\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\")\n    fabric.print(f\"Total parameters {num_parameters(model):,}\")\n\n    model = fabric.setup(model)\n    fabric.load_raw(checkpoint_path, model, strict=True)\n    optimizer = torch.optim.AdamW(\n        model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False\n    )\n    # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), adam_w_mode=True)\n    optimizer = fabric.setup_optimizers(optimizer)\n\n    state = {\"model\": model, \"optimizer\": optimizer, \"hparams\": hparams, \"iter_num\": 0, \"step_count\": 0}\n\n    if resume is True:\n        resume = sorted(out_dir.glob(\"*.pth\"))[-1]\n    if resume:\n        fabric.print(f\"Resuming training from {resume}\")\n        fabric.load(resume, state)\n\n    train_time = time.perf_counter()\n    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)\n    fabric.print(f\"Training time: {(time.perf_counter()-train_time):.2f}s\")\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\")\n\n\ndef train(fabric, state, train_dataloader, val_dataloader, monitor, resume):\n    model = state[\"model\"]\n    optimizer = state[\"optimizer\"]\n\n    # if val_dataloader is not None:\n    #     validate(fabric, model, val_dataloader)  # sanity check\n    model.train()\n\n    meta_model = GPT(model.config).cuda()\n    # \"estimated\" is not as precise as \"measured\". Estimated is optimistic but widely used in the wild.\n    # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,\n    # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead\n    estimated_flops = estimate_flops(meta_model) * micro_batch_size\n    fabric.print(f\"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}\")\n    x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))\n    # measure_flops would run on the meta device and triggers a FusedRMSNorm error, so it is disabled\n    # measured_flops = measure_flops(meta_model, x)\n    # fabric.print(f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\")\n    del meta_model, x\n\n    total_lengths = 0\n    total_t0 = time.perf_counter()\n\n    if fabric.device.type == \"xla\":\n        import torch_xla.core.xla_model as xm\n\n        xm.mark_step()\n\n    initial_iter = state[\"iter_num\"]\n    curr_iter = 0\n\n    loss_func = torch.nn.CrossEntropyLoss()  # FusedCrossEntropyLoss()\n    for i, train_data in enumerate(train_dataloader):\n        # resume loader state. This is not elegant but it works. 
Should rewrite it in the future.\n        if resume:\n            if curr_iter < initial_iter:\n                curr_iter += 1\n                continue\n            else:\n                resume = False\n                curr_iter = -1\n                fabric.barrier()\n                fabric.print(\"resume finished, took {:.2f} seconds\".format(time.perf_counter() - total_t0))\n        if state[\"iter_num\"] >= max_iters:\n            break\n\n        # determine and set the learning rate for this iteration\n        lr = get_lr(state[\"iter_num\"]) if decay_lr else learning_rate\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr\n\n        iter_t0 = time.perf_counter()\n\n        input_ids = train_data[:, 0 : model.config.block_size].contiguous()\n        targets = train_data[:, 1 : model.config.block_size + 1].contiguous()\n        is_accumulating = (state[\"iter_num\"] + 1) % gradient_accumulation_steps != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids)\n            loss = loss_func(logits.transpose(1, 2), targets)\n            # loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n            fabric.backward(loss / gradient_accumulation_steps)\n\n        if not is_accumulating:\n            fabric.clip_gradients(model, optimizer, max_norm=grad_clip)\n            optimizer.step()\n            optimizer.zero_grad()\n            state[\"step_count\"] += 1\n        elif fabric.device.type == \"xla\":\n            xm.mark_step()\n        state[\"iter_num\"] += 1\n        # input_ids: B x L\n        total_lengths += input_ids.size(1)\n        t1 = time.perf_counter()\n        if i % log_step_interval == 0:\n            fabric.print(\n                f\"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:\"\n                f\" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}\"\n                f\" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. \"\n                # print days as well\n                f\" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
\"\n            )\n \n        monitor.on_train_batch_end(\n            state[\"iter_num\"] * micro_batch_size,\n            t1 - total_t0,\n            # this assumes that device FLOPs are the same and that all devices have the same batch size\n            fabric.world_size,\n            state[\"step_count\"],\n            flops_per_batch=estimated_flops,\n            lengths=total_lengths,\n            train_loss = loss.item()\n        )\n\n            \n            \n            \n        if val_dataloader is not None and not is_accumulating and state[\"step_count\"] % eval_step_interval == 0:\n            \n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader)\n            t1 = time.perf_counter() - t0\n            monitor.eval_end(t1)\n            fabric.print(f\"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms\")\n            fabric.log_dict({\"metric/val_loss\": val_loss.item(), \"total_tokens\": model.config.block_size * (state[\"iter_num\"] + 1) * micro_batch_size * fabric.world_size}, state[\"step_count\"])\n            fabric.log_dict({\"metric/val_ppl\": math.exp(val_loss.item()), \"total_tokens\": model.config.block_size * (state[\"iter_num\"] + 1) * micro_batch_size * fabric.world_size}, state[\"step_count\"])\n            fabric.barrier()\n        if not is_accumulating and state[\"step_count\"] % save_step_interval == 0:\n            checkpoint_path = out_dir / f\"iter-{state['iter_num']:06d}-ckpt.pth\"\n            fabric.print(f\"Saving checkpoint to {str(checkpoint_path)!r}\")\n            # only works for pytorch>=2.0\n            fabric.save(checkpoint_path, state)\n\n        \n@torch.no_grad()\ndef validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:\n    fabric.print(\"Validating ...\")\n    model.eval()\n\n    losses = torch.zeros(eval_iters, device=fabric.device)\n    for k, val_data in enumerate(val_dataloader):\n        if k >= eval_iters:\n            break\n        input_ids = val_data[:, 0 : model.config.block_size].contiguous()\n        targets = val_data[:, 1 : model.config.block_size + 1].contiguous()\n        logits = model(input_ids)\n        loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n\n        # loss_func = FusedCrossEntropyLoss()\n        # loss = loss_func(logits, targets)\n        losses[k] = loss.item()\n        \n    out = losses.mean()\n\n    model.train()\n    return out\n\n\ndef create_dataloader(\n    batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split=\"train\"\n) -> DataLoader:\n    datasets = []\n    data_config = train_data_config if split == \"train\" else val_data_config\n    for prefix, _ in data_config:\n        filenames = sorted(glob.glob(str(data_dir / f\"{prefix}*\")))\n        random.seed(seed)\n        random.shuffle(filenames)\n\n        dataset = PackedDataset(\n            filenames,\n            # n_chunks control the buffer size. \n            # Note that the buffer size also impacts the random shuffle\n            # (PackedDataset is an IterableDataset. 
So the shuffle is done by prefetch a buffer and shuffle the buffer)\n            n_chunks=64,\n            block_size=block_size,\n            shuffle=shuffle,\n            seed=seed+fabric.global_rank,\n            num_processes=fabric.world_size,\n            process_rank=fabric.global_rank,\n        )\n        datasets.append(dataset)\n\n    if not datasets:\n        raise RuntimeError(\n            f\"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset.\"\n        )\n\n    weights = [weight for _, weight in data_config]\n    sum_weights = sum(weights)\n    weights = [el / sum_weights for el in weights]\n\n    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)\n\n    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n\n\ndef create_dataloaders(\n    batch_size: int,\n    block_size: int,\n    fabric,\n    train_data_dir: Path = Path(\"data/redpajama_sample\"),\n    val_data_dir: Optional[Path] = None,\n    seed: int = 12345,\n) -> Tuple[DataLoader, DataLoader]:\n    # Increase by one because we need the next word as well\n    effective_block_size = block_size + 1\n    train_dataloader = create_dataloader(\n        batch_size=batch_size,\n        block_size=effective_block_size,\n        fabric=fabric,\n        data_dir=train_data_dir,\n        shuffle=True,\n        seed=seed,\n        split=\"train\"\n    )\n    val_dataloader = (\n        create_dataloader(\n            batch_size=batch_size,\n            block_size=effective_block_size,\n            fabric=fabric,\n            data_dir=val_data_dir,\n            shuffle=False,\n            seed=seed,\n            split=\"validation\"\n        )\n        if val_data_dir\n        else None\n    )\n    return train_dataloader, val_dataloader\n\n\n# learning rate decay scheduler (cosine with warmup)\ndef get_lr(it):\n    # 1) linear warmup for warmup_iters steps\n    if it < warmup_iters:\n        return learning_rate * it / warmup_iters\n    # 2) if it > lr_decay_iters, return min learning rate\n    if it > lr_decay_iters:\n        return min_lr\n    # 3) in between, use cosine decay down to min learning rate\n    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)\n    assert 0 <= decay_ratio <= 1\n    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1\n    return min_lr + coeff * (learning_rate - min_lr)\n\n\nif __name__ == \"__main__\":\n    # Uncomment this line if you see an error: \"Expected is_sm80 to be true, but got false\"\n    # torch.backends.cuda.enable_flash_sdp(False)\n    torch.set_float32_matmul_precision(\"high\")\n\n    from jsonargparse import CLI\n\n    CLI(setup)\n"
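\n\nif __name__ == \"__main__\":\n    # Uncomment this line if you see an error: \"Expected is_sm80 to be true, but got false\"\n    # torch.backends.cuda.enable_flash_sdp(False)\n    torch.set_float32_matmul_precision(\"high\")\n\n    # Example launch (illustrative; the flags mirror the parameters of setup() via\n    # jsonargparse, and the data dir and checkpoint path are assumptions):\n    #   python pretrain/tinyllama_code.py --devices 8 \\\n    #       --train_data_dir ./starcoderdata_python_processed \\\n    #       --model_name tiny_LLaMA_135M_2k \\\n    #       --checkpoint_path path/to/pretrained_weights.pth\n    from jsonargparse import CLI\n\n    CLI(setup)\n"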
  },
  {
    "path": "requirement.txt",
    "content": "arrow==1.3.0\nboto3==1.19.12\nfilelock==3.12.4\nlightning==2.1.2\nlightning-cloud==0.5.52\nlightning-utilities==0.10.0\nmarkdown-it-py==3.0.0\npydantic==2.5.2\npydantic_core==2.14.5\npytorch-lightning==2.1.2\nsentencepiece==0.1.99\nwandb==0.15.3\nzstandard==0.22.0\ntransformers==4.37.2\nnumpy==1.22.4\njsonargparse==4.32.0\nbackoff==2.2.1\nbeautifulsoup4==4.12.3\nblessed==1.20.0\ncroniter==1.4.1\ndateutils==0.6.12\ndeepdiff==6.7.1\neditor==1.6.6\ninquirer==3.4.0\nitsdangerous==2.2.0\nordered-set==4.1.0\nreadchar==4.2.0\nruns==1.2.2\nsoupsieve==2.6\nstarsessions==1.3.0\ntraitlets==5.14.3\nwcwidth==0.2.13\nwebsockets==11.0.3\nxmod==1.8.1\ndatasets"
  },
  {
    "path": "scripts/convert_hf_checkpoint.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport contextlib\nimport gc\nimport json\nimport sys\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport torch\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nfrom lit_gpt import Config\nfrom lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load\n\n\ndef copy_weights_gpt_neox(\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n) -> None:\n    weight_map = {\n        \"gpt_neox.embed_in.weight\": \"transformer.wte.weight\",\n        \"gpt_neox.layers.{}.input_layernorm.bias\": \"transformer.h.{}.norm_1.bias\",\n        \"gpt_neox.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"gpt_neox.layers.{}.attention.query_key_value.bias\": \"transformer.h.{}.attn.attn.bias\",\n        \"gpt_neox.layers.{}.attention.query_key_value.weight\": \"transformer.h.{}.attn.attn.weight\",\n        \"gpt_neox.layers.{}.attention.dense.bias\": \"transformer.h.{}.attn.proj.bias\",\n        \"gpt_neox.layers.{}.attention.dense.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"gpt_neox.layers.{}.attention.rotary_emb.inv_freq\": None,\n        \"gpt_neox.layers.{}.attention.bias\": None,\n        \"gpt_neox.layers.{}.attention.masked_bias\": None,\n        \"gpt_neox.layers.{}.post_attention_layernorm.bias\": \"transformer.h.{}.norm_2.bias\",\n        \"gpt_neox.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"gpt_neox.layers.{}.mlp.dense_h_to_4h.bias\": \"transformer.h.{}.mlp.fc.bias\",\n        \"gpt_neox.layers.{}.mlp.dense_h_to_4h.weight\": \"transformer.h.{}.mlp.fc.weight\",\n        \"gpt_neox.layers.{}.mlp.dense_4h_to_h.bias\": \"transformer.h.{}.mlp.proj.bias\",\n        \"gpt_neox.layers.{}.mlp.dense_4h_to_h.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"gpt_neox.final_layer_norm.bias\": \"transformer.ln_f.bias\",\n        \"gpt_neox.final_layer_norm.weight\": \"transformer.ln_f.weight\",\n        \"embed_out.weight\": \"lm_head.weight\",\n    }\n\n    for name, param in hf_weights.items():\n        if \"gpt_neox.layers\" in name:\n            from_name, number = layer_template(name, 2)\n            to_name = weight_map[from_name]\n            if to_name is None:\n                continue\n            to_name = to_name.format(number)\n        else:\n            to_name = weight_map[name]\n        param = load_param(param, name, dtype)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n\ndef copy_weights_falcon(\n    size: Literal[\"7b\", \"40b\"],\n    state_dict: Dict[str, 
torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n) -> None:\n    weight_map = {\n        \"transformer.word_embeddings.weight\": \"transformer.wte.weight\",\n        \"transformer.h.{}.self_attention.query_key_value.weight\": \"transformer.h.{}.attn.attn.weight\",\n        \"transformer.h.{}.self_attention.dense.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"transformer.h.{}.mlp.dense_h_to_4h.weight\": \"transformer.h.{}.mlp.fc.weight\",\n        \"transformer.h.{}.mlp.dense_4h_to_h.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"transformer.ln_f.bias\": \"transformer.ln_f.bias\",\n        \"transformer.ln_f.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    # the original model definition is different for each size\n    if size == \"7b\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.input_layernorm.bias\": \"transformer.h.{}.norm_1.bias\",\n                \"transformer.h.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n            }\n        )\n    elif size == \"40b\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.ln_attn.bias\": \"transformer.h.{}.norm_1.bias\",\n                \"transformer.h.{}.ln_attn.weight\": \"transformer.h.{}.norm_1.weight\",\n                \"transformer.h.{}.ln_mlp.bias\": \"transformer.h.{}.norm_2.bias\",\n                \"transformer.h.{}.ln_mlp.weight\": \"transformer.h.{}.norm_2.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    for name, param in hf_weights.items():\n        if \"transformer.h\" in name:\n            from_name, number = layer_template(name, 2)\n            to_name = weight_map[from_name].format(number)\n        else:\n            to_name = weight_map[name]\n        param = load_param(param, name, dtype)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n\ndef copy_weights_hf_llama(\n    config: Config,\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n) -> None:\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.self_attn.rotary_emb.inv_freq\": None,\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.swiglu.w1.weight\",\n        \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.swiglu.w2.weight\",\n        \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.swiglu.w3.weight\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n\n    for name, param in hf_weights.items():\n        if \"model.layers\" in name:\n         
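   # q_proj, k_proj and v_proj arrive as separate HF tensors; collect them per layer and fuse after this loop\n         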
   from_name, number = layer_template(name, 2)\n            qkv = qkv_weights.setdefault(number, [None, None, None])\n            if \"q_proj\" in name:\n                qkv[0] = param\n            elif \"k_proj\" in name:\n                qkv[1] = param\n            elif \"v_proj\" in name:\n                qkv[2] = param\n            to_name = weight_map[from_name]\n            if to_name is None:\n                continue\n            to_name = to_name.format(number)\n        else:\n            to_name = weight_map[name]\n        param = load_param(param, name, dtype)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n    for i, (q, k, v) in list(qkv_weights.items()):\n        if q is None or k is None or v is None:\n            # split across different .bin files\n            continue\n        q = load_param(q, f\"layer {i} q\", dtype)\n        k = load_param(k, f\"layer {i} k\", dtype)\n        v = load_param(v, f\"layer {i} v\", dtype)\n        q_per_kv = config.n_head // config.n_query_groups\n        qs = torch.split(q, config.head_size * q_per_kv)\n        ks = torch.split(k, config.head_size)\n        vs = torch.split(v, config.head_size)\n        cycled = [t for group in zip(qs, ks, vs) for t in group]\n        qkv = torch.cat(cycled)\n        state_dict[f\"transformer.h.{i}.attn.attn.weight\"] = qkv\n        del qkv_weights[i]\n\n\ndef layer_template(layer_name: str, idx: int) -> Tuple[str, int]:\n    split = layer_name.split(\".\")\n    number = int(split[idx])\n    split[idx] = \"{}\"\n    from_name = \".\".join(split)\n    return from_name, number\n\n\ndef load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor:\n    if hasattr(param, \"_load_tensor\"):\n        # support tensors loaded via `lazy_load()`\n        print(f\"Loading {name!r} into RAM\")\n        param = param._load_tensor()\n    if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:\n        print(f\"Converting {name!r} from {param.dtype} to {dtype}\")\n        param = param.to(dtype)\n    return param\n\n\n@torch.inference_mode()\ndef convert_hf_checkpoint(\n    *,\n    checkpoint_dir: Path = Path(\"checkpoints/stabilityai/stablelm-base-alpha-3b\"),\n    model_name: Optional[str] = None,\n    dtype: Optional[str] = None,\n) -> None:\n    if model_name is None:\n        model_name = checkpoint_dir.name\n    if dtype is not None:\n        dtype = getattr(torch, dtype)\n\n    config = Config.from_name(model_name)\n    print(f\"Model config {config.__dict__}\")\n    with open(checkpoint_dir / \"lit_config.json\", \"w\") as json_config:\n        json.dump(config.__dict__, json_config)\n\n    if \"falcon\" in model_name:\n        copy_fn = partial(copy_weights_falcon, \"40b\" if config.n_embd == 8192 else \"7b\")\n    elif config._mlp_class == \"LLaMAMLP\":\n        # holder to reconstitute the split q, k, v\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)\n    else:\n        copy_fn = copy_weights_gpt_neox\n\n    # initialize a new empty state dict to hold our new weights\n    sd = {}\n\n    # Load the json file containing weight mapping\n    pytorch_bin_map_json_path = checkpoint_dir / \"pytorch_model.bin.index.json\"\n    if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file\n        with open(pytorch_bin_map_json_path) as json_map:\n            bin_index = json.load(json_map)\n        
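# build the unique set of shard files listed in the index's weight_map\n        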
bin_files = {checkpoint_dir / bin_name for bin_name in bin_index[\"weight_map\"].values()}\n    else:\n        bin_files = set(checkpoint_dir.glob(\"*.bin\"))\n    if not bin_files:\n        raise ValueError(f\"Expected {str(checkpoint_dir)!r} to contain .bin files\")\n\n    with incremental_save(checkpoint_dir / \"lit_model.pth\") as saver:\n        # for checkpoints that split the QKV across several files, we need to keep all the bin files\n        # open, so we use `ExitStack` to close them all together at the end\n        with contextlib.ExitStack() as stack:\n            for bin_file in sorted(bin_files):\n                print(\"Processing\", bin_file)\n                hf_weights = stack.enter_context(lazy_load(bin_file))\n                copy_fn(sd, hf_weights, saver=None, dtype=dtype)\n            gc.collect()\n        print(\"Saving converted checkpoint\")\n        saver.save(sd)\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n\n    CLI(convert_hf_checkpoint)\n"
  },
  {
    "path": "scripts/convert_lit_checkpoint.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport contextlib\nimport gc\nimport sys\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Dict, Literal, Optional, Tuple, Union\nfrom dataclasses import asdict\nimport json\nimport torch\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nfrom lit_gpt import Config\nfrom lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load\n# from scripts.convert_hf_checkpoint import layer_template, load_param\n\n\ndef layer_template(layer_name: str, idx: int) -> Tuple[str, int]:\n    split = layer_name.split(\".\")\n    number = int(split[idx])\n    split[idx] = \"{}\"\n    from_name = \".\".join(split)\n    return from_name, number\n\n\ndef load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor:\n    if hasattr(param, \"_load_tensor\"):\n        # support tensors loaded via `lazy_load()`\n        print(f\"Loading {name!r} into RAM\")\n        param = param._load_tensor()\n    if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:\n        print(f\"Converting {name!r} from {param.dtype} to {dtype}\")\n        param = param.to(dtype)\n    return param\ndef copy_weights_falcon(\n    size: Literal[\"7b\", \"40b\"],\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n):\n    weight_map = {\n        \"transformer.wte.weight\": \"transformer.word_embeddings.weight\",\n        \"transformer.h.{}.attn.attn.weight\": \"transformer.h.{}.self_attention.query_key_value.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"transformer.h.{}.self_attention.dense.weight\",\n        \"transformer.h.{}.mlp.fc.weight\": \"transformer.h.{}.mlp.dense_h_to_4h.weight\",\n        \"transformer.h.{}.mlp.proj.weight\": \"transformer.h.{}.mlp.dense_4h_to_h.weight\",\n        \"transformer.ln_f.bias\": \"transformer.ln_f.bias\",\n        \"transformer.ln_f.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    # the original model definition is different for each size\n    if size == \"7b\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.norm_1.bias\": \"transformer.h.{}.input_layernorm.bias\",\n                \"transformer.h.{}.norm_1.weight\": \"transformer.h.{}.input_layernorm.weight\",\n            }\n        )\n    elif size == \"40b\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.norm_1.bias\": \"transformer.h.{}.ln_attn.bias\",\n                \"transformer.h.{}.norm_1.weight\": \"transformer.h.{}.ln_attn.weight\",\n                \"transformer.h.{}.norm_2.bias\": \"transformer.h.{}.ln_mlp.bias\",\n                
\"transformer.h.{}.norm_2.weight\": \"transformer.h.{}.ln_mlp.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    for name, param in lit_weights.items():\n        if \"transformer.h\" in name:\n            from_name, number = layer_template(name, 2)\n            to_name = weight_map[from_name].format(number)\n        else:\n            to_name = weight_map[name]\n        param = load_param(param, name, None)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n\ndef copy_weights_gpt_neox(\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"gpt_neox.embed_in.weight\",\n        \"transformer.h.{}.norm_1.bias\": \"gpt_neox.layers.{}.input_layernorm.bias\",\n        \"transformer.h.{}.norm_1.weight\": \"gpt_neox.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.attn.attn.bias\": \"gpt_neox.layers.{}.attention.query_key_value.bias\",\n        \"transformer.h.{}.attn.attn.weight\": \"gpt_neox.layers.{}.attention.query_key_value.weight\",\n        \"transformer.h.{}.attn.proj.bias\": \"gpt_neox.layers.{}.attention.dense.bias\",\n        \"transformer.h.{}.attn.proj.weight\": \"gpt_neox.layers.{}.attention.dense.weight\",\n        \"transformer.h.{}.norm_2.bias\": \"gpt_neox.layers.{}.post_attention_layernorm.bias\",\n        \"transformer.h.{}.norm_2.weight\": \"gpt_neox.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.mlp.fc.bias\": \"gpt_neox.layers.{}.mlp.dense_h_to_4h.bias\",\n        \"transformer.h.{}.mlp.fc.weight\": \"gpt_neox.layers.{}.mlp.dense_h_to_4h.weight\",\n        \"transformer.h.{}.mlp.proj.bias\": \"gpt_neox.layers.{}.mlp.dense_4h_to_h.bias\",\n        \"transformer.h.{}.mlp.proj.weight\": \"gpt_neox.layers.{}.mlp.dense_4h_to_h.weight\",\n        \"transformer.ln_f.bias\": \"gpt_neox.final_layer_norm.bias\",\n        \"transformer.ln_f.weight\": \"gpt_neox.final_layer_norm.weight\",\n        \"lm_head.weight\": \"embed_out.weight\",\n    }\n\n    for name, param in lit_weights.items():\n        if \"transformer.h\" in name:\n            from_name, number = layer_template(name, 2)\n            to_name = weight_map[from_name].format(number)\n        else:\n            to_name = weight_map[name]\n        param = load_param(param, name, None)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n\ndef copy_weights_llama(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n):\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n        \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n        \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n        \"transformer.ln_f.weight\": \"model.norm.weight\",\n        
\"lm_head.weight\": \"lm_head.weight\",\n    }\n    for name, param in lit_weights.items():\n        if name.endswith(\".attn.attn.weight\"):\n            from_name, number = layer_template(name, 2)\n            q = \"model.layers.{}.self_attn.q_proj.weight\".format(number)\n            k = \"model.layers.{}.self_attn.k_proj.weight\".format(number)\n            v = \"model.layers.{}.self_attn.v_proj.weight\".format(number)\n            qkv = load_param(param, name,None)\n            qp, kp, vp = tensor_split(qkv, config)\n            for to_name, param in zip((q, k, v), (qp, kp, vp)):\n                if saver is not None:\n                    param = saver.store_early(param)\n                state_dict[to_name] = param\n        elif \"transformer.h\" in name:\n            from_name, number = layer_template(name, 2)\n            to_name = weight_map[from_name]\n            \n            if to_name is None:\n                continue\n            to_name = to_name.format(number)\n            param = load_param(param, name,None)\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n        else:\n            to_name = weight_map[name]\n            param = load_param(param, name, None)\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef tensor_split(\n    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    def kstart(start, blen, klen) -> int:\n        \"\"\"returns start index of keys in batch\"\"\"\n        return start + (blen - (klen * 2))\n\n    def vstart(start, blen, klen) -> int:\n        \"\"\"returns start index of values in batch\"\"\"\n        return start + blen - klen\n\n    def vend(start, blen) -> int:\n        \"\"\"returns last index of values in batch\"\"\"\n        return start + blen\n\n    # num observations\n    nobs = param.shape[0]\n    # batch length\n    blen = nobs // config.n_query_groups\n    # key length in batch\n    klen = config.head_size\n    # value length in batch\n    vlen = config.head_size\n    # the starting index of each new batch\n    starts = range(0, nobs, blen)\n    # the indices to splice on\n    splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts]\n\n    qc = ()\n    kc = ()\n    vc = ()\n\n    for splice in splices:\n        qs, ks, vs, ve = splice\n        qc += (param[qs:ks, :],)\n        kc += (param[ks:vs, :],)\n        vc += (param[vs:ve, :],)\n\n    q = torch.cat(qc)\n    k = torch.cat(kc)\n    v = torch.cat(vc)\n\n    return q, k, v\n\n\ndef maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n    return lit_weights.get(\"model\", lit_weights)\n\n\ndef check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:\n    weight_names = {wk.split(\".\")[-1] for wk in lit_weights}\n    # LoRA or QLoRA\n    if any(\"lora\" in wn for wn in weight_names):\n        raise ValueError(\"Model weights must be merged using `lora.merge_lora_weights()` before conversion.\")\n    # adapter v2. adapter_bias will only be in adapter_v2\n    elif \"adapter_bias\" in weight_names:\n        raise NotImplementedError(\"Converting models finetuned with adapter_v2 not yet supported.\")\n    # adapter. 
gating_factor is in adapter and adapter_v2\n    elif \"gating_factor\" in weight_names:\n        raise NotImplementedError(\"Converting models finetuned with adapter not yet supported.\")\n\n\ndef get_tinyllama_init_hf_config() -> dict:\n    return {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": None,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": None,\n        \"max_position_embeddings\": None,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": None,\n        \"num_hidden_layers\": None,\n        \"num_key_value_heads\": None,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": None,\n        \"rope_scaling\": None,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float32\",\n        \"transformers_version\": \"4.31.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": None,\n    }\n\n\ndef convert_config_lit_to_hf(lit_config_dict: dict) -> dict:\n    lit_hf_mapping = {\n        \"block_size\": \"max_position_embeddings\",\n        \"vocab_size\": \"vocab_size\",\n        \"n_layer\": \"num_hidden_layers\",\n        \"n_embd\": \"hidden_size\",\n        \"n_head\": \"num_attention_heads\",\n        \"n_query_groups\": \"num_key_value_heads\",\n        \"intermediate_size\": \"intermediate_size\",\n        \"norm_eps\": \"rms_norm_eps\",\n\n    }\n    hf_config_dict = get_tinyllama_init_hf_config()\n    \n    for lit_key, hf_key in lit_hf_mapping.items():\n        hf_config_dict[hf_key] = lit_config_dict[lit_key]\n    return hf_config_dict\n\n\n@torch.inference_mode()\ndef convert_lit_checkpoint(*, \n    checkpoint_name: str, \n    out_dir: Path, \n    model_name: str,\n    model_only: bool = True) -> None:\n    config = Config.from_name(model_name)\n\n    if \"falcon\" in model_name:\n        copy_fn = partial(copy_weights_falcon, \"40b\" if config.n_embd == 8192 else \"7b\")\n    elif config._mlp_class == \"LLaMAMLP\":\n        copy_fn = partial(copy_weights_llama, config)\n    else:\n        copy_fn = copy_weights_gpt_neox\n\n    # initialize a new empty state dict to hold our new weights\n    sd = {}\n\n    # checkpoint_name cannot be hardcoded because there exists different outputs such as\n    # (\"lit_model_finetuned.pth\", \"lit_model_lora_finetuned.pth\", \"lit_model_adapter_finetuned.pth\"\")\n    pth_file = out_dir / checkpoint_name\n    bin_file = pth_file.with_suffix(\".bin\")\n\n    with incremental_save(bin_file) as saver:\n        with contextlib.ExitStack() as stack:\n            lit_weights = stack.enter_context(lazy_load(pth_file))\n            lit_weights = maybe_unwrap_state_dict(lit_weights)\n            check_conversion_supported(lit_weights)\n            # Incremental save will trigger error\n            copy_fn(sd, lit_weights, saver=None)\n            gc.collect()\n        saver.save(sd)\n\n    # convert lit config file to hf-style\n    if not model_only:\n        print('Converting config file...')\n        lit_config = asdict(config)\n        hf_config = convert_config_lit_to_hf(lit_config)\n        config_path = out_dir / \"config.json\"\n        with open(config_path, \"w\") as f:\n            json.dump(hf_config, f, indent=4)\n\n\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n\n    CLI(convert_lit_checkpoint, as_positional=False)\n"
  },
  {
    "path": "scripts/convert_lit_model_to_hf.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\niter=104000\nmodel_path=./out_bak_100m_2k_code_iter_$iter/\nmkdir $model_path\ncheckpoint_name=iter-${iter}-ckpt.pth\ncp ./out/tinyllama_135M_2k/$checkpoint_name $model_path\npython scripts/convert_lit_checkpoint.py \\\n    --checkpoint_name=$checkpoint_name\\\n    --out_dir=$model_path \\\n    --model_name='tiny_LLaMA_135M_2k' \\\n    --model_only=False\n\ncp ./scripts/tokenizer/* ${model_path}\nmv ${model_path}/iter-${iter}-ckpt.bin ${model_path}/pytorch_model.bin\nrm ${model_path}/${checkpoint_name}\n"
  },
  {
    "path": "scripts/datasets_statistics.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport glob\nimport os\n\ndata_path = '../slim_processed'\nchunk_size=4097*1024\n\nprefix = 'train_*'\ndata_split = {\n    'starcoder': 'train_starcoder',\n    'slimpajama_wiki': 'train_wikipedia_slimpajama',\n    'slimpajama_git': 'train_github_slimpajama',\n    'slimpajama_book': 'train_book_slimpajama',\n    'slimpajama': 'train_slimpajama',\n    'mnbvc': 'train_mnbvc',\n    'skypile': 'train_skypile',\n    'openwebmath': 'train_openwebmath',\n    'project_gutenberg': 'train_project_gutenberg',\n}\ndata_epoches = {\n    'starcoder': 1.0,\n    'slimpajama_wiki': 1.0,\n    'slimpajama_git': 1.0,\n    'slimpajama_book': 1.0,\n    'slimpajama': 1.0,\n    'mnbvc': 1.0,\n    'skypile': 1.0,\n    'openwebmath': 1.0,\n    'project_gutenberg': 1.0,\n}\ndata_statis = {}\nfor data_name in data_split:\n    data_statis[data_name] = 0\ntotal_chunks = 0\ntotal_tokens = 0\n\nfilenames = glob.glob(os.path.join(data_path, prefix), recursive=True)\nfor filename in filenames:\n    for data_name, pref in data_split.items():\n        if filename[len(os.path.dirname(filename))+1:].startswith(pref):\n            data_statis[data_name] += 1\n            total_chunks += 1\nprint('statistics:')\nfor data_name, num_chunk in data_statis.items():\n    print(f'{num_chunk*chunk_size/1000/1000/1000} B tokens, ', f'{num_chunk} chunks, ', data_name)\n    total_tokens += num_chunk*chunk_size\nprint(f\"{total_tokens/1000/1000/1000} B tokens\", f\"{total_chunks} chunks in total.\")\n\nprint(\"percentage:\")\nfor data_name, num_chunk in data_statis.items():\n    print(f'1.0 epoches, {num_chunk*chunk_size / total_tokens} %, ', data_name)\n\nprint(\"weighted:\")\ntotal_tokens = 0\nfor data_name, num_chunk in data_statis.items():\n    print(f'{num_chunk*chunk_size*data_epoches[data_name]/1000/1000/1000} B tokens, ', f'{num_chunk} chunks, ', data_name)\n    total_tokens += num_chunk*chunk_size*data_epoches[data_name]\n\nfor data_name, num_chunk in data_statis.items():\n    print(f'{data_epoches[data_name]} epoches, {num_chunk*chunk_size*data_epoches[data_name] / total_tokens*100} %, ', data_name)\nprint(f\"{total_tokens/1000/1000/1000} B tokens\", f\"{total_chunks} chunks in total.\")\n"
  },
  {
    "path": "scripts/prepare_mnbvc.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tqdm import tqdm\nfrom multiprocessing import Process, cpu_count\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimport lit_gpt.packed_dataset as packed_dataset\nfrom lit_gpt import Tokenizer\n\nimport pandas as pd\nimport gzip\n\ndef prepare_full(\n    source_path: Path,\n    tokenizer_path: Path,\n    destination_path: Path,\n    chunk_size: int,\n    split: str=\"train\",\n    filenames_subset: List[str] = None,\n    process_id: int = 0\n) -> None:\n    import zstandard as zstd\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n\n    tokenizer = Tokenizer(tokenizer_path)\n\n    # Use the provided filenames_subset or default to all filenames\n    filenames = filenames_subset \n    \n    if not filenames:\n        raise RuntimeError(\n            f\"No files matching  found at {source_path}. \\n\"\n            \"Make sure you download the data...\"\n        )\n\n    builder = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_mnbvc_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n\n    for filepath in filenames:\n        print(f\"Processing {filepath}\")\n        try:\n            # contents = pd.read_parquet(filepath, engine='pyarrow')['content']\n            if 'code/metadata/' in filepath:\n                print(\"Not use metadata!\")\n                continue\n            with gzip.open(open(filepath, \"rb\"), mode=\"rt\") as f:\n                for row in tqdm(f):\n                    text = json.loads(row)[\"text\"]\n                    text_ids = tokenizer.encode(text)\n                    builder.add_array(np.array(text_ids, dtype=builder.dtype))\n        except:\n            print(f\"Error reading {filepath}!!\")\n            continue\n\n    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details\n    # builder.write_reminder()\n\n\ndef prepare(\n    source_path: Path = Path(\"data/RedPajama-Data-1T-Sample\"),\n    tokenizer_path: Path = Path(\"checkpoints/lit-llama/tokenizer.model\"),\n    destination_path: Path = Path(\"data/red_pajama_sample\"),\n    chunk_size: int = 4097 * 1024,\n    split: str=\"train\",\n    percentage: float = 1.0,\n    filenames_subset: List[str] = None,\n) -> None:\n    import time\n    assert split == \"train\" #  starcoder only has train data\n    filenames = glob.glob(os.path.join(source_path, \"*/*/*.jsonl.gz\"), recursive=True)\n    filenames += glob.glob(os.path.join(source_path, 
\"*/*/*/*.jsonl.gz\"), recursive=True)\n    print(len(filenames))\n    # only retrain subsets that follow the prefix in filenames_subset\n    if filenames_subset:\n        filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])]\n    filenames = filenames[:int(len(filenames) * percentage)]\n    num_processes = 64\n    chunked_filenames = np.array_split(filenames, num_processes)\n\n    processes = []\n    start_time = time.time()\n\n    for i, subset in enumerate(chunked_filenames):\n        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))\n        processes.append(p)\n        p.start()\n\n    for p in processes:\n        p.join()\n    end_time = time.time()\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n    CLI(prepare)\n"
  },
  {
    "path": "scripts/prepare_mnbvc.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=../MNBVC\ntarget_path=../mnbvc_processed/\ntokenizer_path=./scripts/tokenizer\n\npython ./scripts/prepare_mnbvc.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split train \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/prepare_project_gutenberg.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tqdm import tqdm\nfrom multiprocessing import Process, cpu_count\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimport lit_gpt.packed_dataset as packed_dataset\nfrom lit_gpt import Tokenizer\n\nimport pandas as pd\n\n\ndef prepare_full(\n    source_path: Path,\n    tokenizer_path: Path,\n    destination_path: Path,\n    chunk_size: int,\n    split: str=\"train\",\n    filenames_subset: List[str] = None,\n    process_id: int = 0\n) -> None:\n    import zstandard as zstd\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n\n    tokenizer = Tokenizer(tokenizer_path)\n\n    # Use the provided filenames_subset or default to all filenames\n    filenames = filenames_subset \n    \n    if not filenames:\n        raise RuntimeError(\n            f\"No files matching  found at {source_path}. \\n\"\n            \"Make sure you download the data...\"\n        )\n\n    builder = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_project_gutenberg_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n\n    for filepath in filenames:\n        print(f\"Processing {filepath}\")\n        print(filepath)\n        contents = pd.read_parquet(filepath, engine='pyarrow')['text']\n        #try:\n        #    contents = pd.read_parquet(filepath, engine='pyarrow')['text']\n        #except:\n        #    print(f\"Error reading {filepath}!!\")\n        #    continue\n        for text in contents:\n            text_ids = tokenizer.encode(text)\n            builder.add_array(np.array(text_ids, dtype=builder.dtype))\n\n    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details\n    # builder.write_reminder()\n\n\ndef prepare(\n    source_path: Path = Path(\"data/RedPajama-Data-1T-Sample\"),\n    tokenizer_path: Path = Path(\"checkpoints/lit-llama/tokenizer.model\"),\n    destination_path: Path = Path(\"data/red_pajama_sample\"),\n    chunk_size: int = 2049 * 2048,\n    split: str=\"train\",\n    percentage: float = 1.0,\n    filenames_subset: List[str] = None,\n) -> None:\n    import time\n    assert split == \"train\" #  starcoder only has train data\n    filenames = glob.glob(os.path.join(source_path, \"*/*.parquet\"), recursive=True)\n    # only retrain subsets that follow the prefix in filenames_subset\n    if filenames_subset:\n        filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])]\n    filenames = filenames[:int(len(filenames) * 
percentage)]\n    num_processes = min(cpu_count(), len(filenames))\n    chunked_filenames = np.array_split(filenames, num_processes)\n\n    processes = []\n    start_time = time.time()\n\n    for i, subset in enumerate(chunked_filenames):\n        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))\n        processes.append(p)\n        p.start()\n\n    for p in processes:\n        p.join()\n    end_time = time.time()\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n    CLI(prepare)\n"
  },
  {
    "path": "scripts/prepare_project_gutenberg.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=../project_gutenberg\ntarget_path=../slim_processed\n\n# train: 873 secs, ~20G\ntokenizer_path=./scripts/tokenizer\n\npython ./scripts/prepare_project_gutenberg.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split train \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/prepare_skypile.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tqdm import tqdm\nfrom multiprocessing import Process, cpu_count\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimport lit_gpt.packed_dataset as packed_dataset\nfrom lit_gpt import Tokenizer\n\nimport pandas as pd\n\n\ndef prepare_full(\n    source_path: Path,\n    tokenizer_path: Path,\n    destination_path: Path,\n    chunk_size: int,\n    split: str=\"train\",\n    filenames_subset: List[str] = None,\n    process_id: int = 0\n) -> None:\n    import zstandard as zstd\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n\n    tokenizer = Tokenizer(tokenizer_path)\n\n    # Use the provided filenames_subset or default to all filenames\n    filenames = filenames_subset \n   \n    if not filenames:\n        raise RuntimeError(\n            f\"No files matching  found at {source_path}. \\n\"\n            \"Make sure you download the data...\"\n        )\n\n    builder = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_skypile_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n\n    for filepath in filenames:\n        print(f\"Processing {filepath}\")\n        try:\n            # contents = pd.read_parquet(filepath, engine='pyarrow')['content']\n            contents = pd.read_json(path_or_buf=filepath, lines=True)['text']\n        except:\n            print(f\"Error reading {filepath}!!\")\n            continue\n        for text in contents:\n            text_ids = tokenizer.encode(text)\n            builder.add_array(np.array(text_ids, dtype=builder.dtype))\n\n    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details\n    # builder.write_reminder()\n\n\ndef prepare(\n    source_path: Path = Path(\"data/RedPajama-Data-1T-Sample\"),\n    tokenizer_path: Path = Path(\"checkpoints/lit-llama/tokenizer.model\"),\n    destination_path: Path = Path(\"data/red_pajama_sample\"),\n    chunk_size: int = 2048 * 2049,\n    split: str=\"train\",\n    percentage: float = 1.0,\n    filenames_subset: List[str] = None,\n) -> None:\n    import time\n    assert split == \"train\" #  starcoder only has train data\n    filenames = glob.glob(os.path.join(source_path, \"*/*.jsonl\"), recursive=True)\n    # only retrain subsets that follow the prefix in filenames_subset\n    if filenames_subset:\n        filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])]\n    filenames = filenames[:int(len(filenames) * percentage)]\n    num_processes = 
cpu_count()\n    chunked_filenames = np.array_split(filenames, num_processes)\n\n    processes = []\n    start_time = time.time()\n\n    for i, subset in enumerate(chunked_filenames):\n        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))\n        processes.append(p)\n        p.start()\n\n    for p in processes:\n        p.join()\n    end_time = time.time()\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n    CLI(prepare)\n"
  },
  {
    "path": "scripts/prepare_skypile.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=../SkyPile-150B/\ntarget_path=../skypile_processed/\ntokenizer_path=./scripts/tokenizer\n\npython ./scripts/prepare_skypile.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split train \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/prepare_slimpajama.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tqdm import tqdm\nfrom multiprocessing import Process, cpu_count\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimport lit_gpt.packed_dataset as packed_dataset\nfrom lit_gpt.tokenizer import Tokenizer\n\n# Filename for SlimPajama\nslimpajama_sets = {\n    \"train\": \"train/chunk*/*\",\n    \"validation\": \"validation/chunk*/*\",\n    \"test\": \"test/chunk*/*\",\n}\n\n\ndef prepare_full(\n    source_path: Path,\n    tokenizer_path: Path,\n    destination_path: Path,\n    chunk_size: int,\n    split: str=\"train\",\n    filenames_subset: List[str] = None,\n    process_id: int = 0\n) -> None:\n    import zstandard as zstd\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n\n    tokenizer = Tokenizer(tokenizer_path)\n    print(tokenizer_path, ' DDDDDDDDDDDDDDDDD')\n    print(tokenizer.bos_id)\n\n    # Use the provided filenames_subset or default to all filenames\n    filenames = filenames_subset \n    \n    if not filenames:\n        raise RuntimeError(\n            f\"No files matching {slimpajama_sets[split]} found at {source_path}. 
\\n\"\n            \"Make sure you download the data...\"\n        )\n\n    builder = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_slimpajama_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n    builder_wiki = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_wikipedia_slimpajama_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n    for filepath in filenames:\n        print(f\"Processing {filepath}\")\n        with zstd.open(open(filepath, \"rb\"), \"rt\", encoding=\"utf-8\") as f:\n            for row in tqdm(f):\n                text = json.loads(row)[\"text\"]\n                text_ids = tokenizer.encode(text)\n                if json.loads(row)[\"meta\"][\"redpajama_set_name\"]=='RedPajamaBook':\n                    print(\"skip red pajama book!!!\")\n                    continue\n                if split == 'train' and json.loads(row)[\"meta\"][\"redpajama_set_name\"] == \"RedPajamaWikipedia\":\n                    builder_wiki.add_array(np.array(text_ids, dtype=builder_wiki.dtype))\n                else:\n                    builder.add_array(np.array(text_ids, dtype=builder.dtype))\n\n    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details\n    # builder.write_reminder()\n\n\ndef prepare(\n    source_path: Path = Path(\"data/RedPajama-Data-1T-Sample\"),\n    tokenizer_path: Path = Path(\"checkpoints/lit-llama/tokenizer.model\"),\n    destination_path: Path = Path(\"data/red_pajama_sample\"),\n    chunk_size: int = 2048 * 2049,\n    split: str=\"train\",\n    percentage: float = 1.0,\n) -> None:\n    import time\n\n    filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True)\n    filenames = filenames[:int(len(filenames) * percentage)]\n\n    num_processes = 16#cpu_count()\n    chunked_filenames = np.array_split(filenames, num_processes)\n\n    processes = []\n    start_time = time.time()\n\n    for i, subset in enumerate(chunked_filenames):\n        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))\n        processes.append(p)\n        p.start()\n\n    for p in processes:\n        p.join()\n    end_time = time.time()\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n    CLI(prepare)\n"
  },
  {
    "path": "scripts/prepare_slimpajama_train.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=../SlimPajama-627B\ntarget_path=../slim_processed\ntokenizer_path=./scripts/tokenizer\n\npython3 ./scripts/prepare_slimpajama.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split train \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/prepare_slimpajama_valid.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=../SlimPajama-627B\ntarget_path=../slim_validation_processed\ntokenizer_path=./scripts/tokenizer\npython3 ./scripts/prepare_slimpajama.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split validation \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/prepare_starcoder.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tqdm import tqdm\nfrom multiprocessing import Process, cpu_count\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimport lit_gpt.packed_dataset as packed_dataset\nfrom lit_gpt import Tokenizer\n\nimport pandas as pd\n\n\ndef prepare_full(\n    source_path: Path,\n    tokenizer_path: Path,\n    destination_path: Path,\n    chunk_size: int,\n    split: str=\"train\",\n    filenames_subset: List[str] = None,\n    process_id: int = 0\n) -> None:\n    import zstandard as zstd\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n\n    tokenizer = Tokenizer(tokenizer_path)\n\n    # Use the provided filenames_subset or default to all filenames\n    filenames = filenames_subset \n    \n    if not filenames:\n        raise RuntimeError(\n            f\"No files matching  found at {source_path}. \\n\"\n            \"Make sure you download the data...\"\n        )\n\n    builder = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_starcoder_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n\n    for filepath in filenames:\n        print(f\"Processing {filepath}\")\n        try:\n            contents = pd.read_parquet(filepath, engine='pyarrow')['content']\n        except:\n            print(f\"Error reading {filepath}!!\")\n            continue\n        for text in contents:\n            text_ids = tokenizer.encode(text)\n            builder.add_array(np.array(text_ids, dtype=builder.dtype))\n\n    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details\n    # builder.write_reminder()\n\n\ndef prepare(\n    source_path: Path = Path(\"data/RedPajama-Data-1T-Sample\"),\n    tokenizer_path: Path = Path(\"checkpoints/lit-llama/tokenizer.model\"),\n    destination_path: Path = Path(\"data/red_pajama_sample\"),\n    chunk_size: int = 2048 * 2049,\n    split: str=\"train\",\n    percentage: float = 1.0,\n    filenames_subset: List[str] = None,\n) -> None:\n    import time\n    assert split == \"train\" #  starcoder only has train data\n    filenames = glob.glob(os.path.join(source_path, \"*/*.parquet\"), recursive=True)\n    # only retrain subsets that follow the prefix in filenames_subset\n    if filenames_subset:\n        filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])]\n    filenames = filenames[:int(len(filenames) * percentage)]\n    num_processes = cpu_count()#64\n    chunked_filenames = np.array_split(filenames, 
num_processes)\n\n    processes = []\n    start_time = time.time()\n\n    for i, subset in enumerate(chunked_filenames):\n        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))\n        processes.append(p)\n        p.start()\n\n    for p in processes:\n        p.join()\n    end_time = time.time()\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n    CLI(prepare)\n"
  },
  {
    "path": "scripts/prepare_starcoder.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=./starcoderdata/\ntarget_path=./starcoderdata_processed/\n#train: 12549 secs, ~500G\ntokenizer_path=./scripts/tokenizer\n\npython prepare_starcoder.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split train \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/prepare_starcoder_python.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tqdm import tqdm\nfrom multiprocessing import Process, cpu_count\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nimport lit_gpt.packed_dataset as packed_dataset\nfrom lit_gpt import Tokenizer\n\nimport pandas as pd\n\n\ndef prepare_full(\n    source_path: Path,\n    tokenizer_path: Path,\n    destination_path: Path,\n    chunk_size: int,\n    split: str=\"train\",\n    filenames_subset: List[str] = None,\n    process_id: int = 0\n) -> None:\n    import zstandard as zstd\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n\n    tokenizer = Tokenizer(tokenizer_path)\n\n    # Use the provided filenames_subset or default to all filenames\n    filenames = filenames_subset \n    \n    if not filenames:\n        raise RuntimeError(\n            f\"No files matching  found at {source_path}. \\n\"\n            \"Make sure you download the data...\"\n        )\n\n    builder = packed_dataset.PackedDatasetBuilder(\n        outdir=destination_path,\n        prefix=f\"{split}_starcoder_{process_id}\",  # Use process_id to differentiate builders\n        chunk_size=chunk_size,\n        sep_token=tokenizer.bos_id,\n        dtype=\"auto\",\n        vocab_size=tokenizer.vocab_size,\n    )\n\n    for filepath in filenames:\n        print(f\"Processing {filepath}\")\n        contents = pd.read_parquet(filepath, engine='pyarrow')['content']\n        try:\n            contents = pd.read_parquet(filepath, engine='pyarrow')['content']\n        except:\n            print(f\"Error reading {filepath}!!\")\n            continue\n        for text in contents:\n            text_ids = tokenizer.encode(text)\n            builder.add_array(np.array(text_ids, dtype=builder.dtype))\n\n    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details\n    # builder.write_reminder()\n\n\ndef prepare(\n    source_path: Path = Path(\"data/RedPajama-Data-1T-Sample\"),\n    tokenizer_path: Path = Path(\"checkpoints/lit-llama/tokenizer.model\"),\n    destination_path: Path = Path(\"data/red_pajama_sample\"),\n    chunk_size: int = 2048 * 2049,\n    split: str=\"train\",\n    percentage: float = 1.0,\n    filenames_subset: List[str] = None,\n) -> None:\n    import time\n    assert split == \"train\" #  starcoder only has train data\n    filenames = glob.glob(os.path.join(source_path, \"python/*.parquet\"), recursive=True)\n    # only retrain subsets that follow the prefix in filenames_subset\n    if filenames_subset:\n        filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])]\n    filenames = filenames[:int(len(filenames) * percentage)]\n    
num_processes = min(len(filenames), cpu_count())\n    chunked_filenames = np.array_split(filenames, num_processes)\n\n    processes = []\n    start_time = time.time()\n\n    for i, subset in enumerate(chunked_filenames):\n        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))\n        processes.append(p)\n        p.start()\n\n    for p in processes:\n        p.join()\n    end_time = time.time()\n    elapsed_time = end_time - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n    CLI(prepare)\n"
  },
  {
    "path": "scripts/prepare_starcoder_python.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nsource_path=../starcoderdata/\ntarget_path=../starcoderdata_python_processed/\n#train: 12549 secs, ~500G\ntokenizer_path=./scripts/tokenizer\n\npython ./scripts/prepare_starcoder_python.py \\\n    --source_path $source_path \\\n    --tokenizer_path $tokenizer_path  \\\n    --destination_path $target_path \\\n    --split train \\\n    --percentage 1.0\n"
  },
  {
    "path": "scripts/run_lm_eval.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nexport HF_DATASETS_CACHE=\"./huggingface_data\"\n\nfor task in wikitext lambada_openai winogrande piqa sciq wsc arc_easy arc_challenge logiqa hellaswag mmlu truthfulqa gsm8k ceval-valid\ndo\n    export CUDA_VISIBLE_DEVICES=\"0\"\n    lm_eval --model hf \\\n        --tasks $task \\\n        --model_args pretrained=/path/to/your/huggingface/model \\\n        --device cuda:0 \\\n        --batch_size 2 \ndone\n"
  },
  {
    "path": "speculative_decoding/codellama_spec.py",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport sys\nimport argparse\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList\nimport torch\nimport time\nfrom datasets import load_from_disk\nimport numpy as np\nimport json\nimport random\nimport os\nimport tabulate as tab\nfrom collections import defaultdict\nfrom typing import Iterable, Dict\nimport gzip\nimport logging\n\nclass LlamaModelEval(LlamaForCausalLM):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.tokenizer = None\n        self.model_name = None\n\n    def forward(self, *args, **kwargs):\n        st = time.perf_counter()\n        outputs = super().forward(*args, **kwargs)\n        en = time.perf_counter()\n        logging.critical(f\"[PROFILE] model_decoder_forward {en-st}\")\n        return outputs\n\n\nclass LlamaModelEval_Draft(LlamaForCausalLM):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.tokenizer = None\n        self.model_name = None\n\n    def forward(self, *args, **kwargs):\n        st = time.perf_counter()\n        outputs = super().forward(*args, **kwargs) \n        en = time.perf_counter()\n        logging.critical(f\"[PROFILE] draft_model_decoder_forward {en-st}\")\n        codellama_extra_token_number = 16\n        outputs['logits'] = torch.nn.functional.pad(outputs['logits'], (0, codellama_extra_token_number), value=float('-inf'))\n        return outputs\n\n\n\nclass ProfileLLM:\n    start_idx_arr = []\n    end_idx_arr = []\n    generate_times_arr = []\n    num_tokens_out_arr = []\n    num_tokens_in_arr = []\n    exec_time_arr = []\n    prompt_time_arr = []\n    token_time_arr = []\n    decoder_time_arr = []\n\n    @classmethod\n    def clear_entries(cls):\n        cls.start_idx_arr = []\n        cls.end_idx_arr = []\n        cls.generate_times_arr = []\n        cls.num_tokens_out_arr = []\n        cls.num_tokens_in_arr = []\n        cls.exec_time_arr = []\n        cls.prompt_time_arr = []\n        cls.token_time_arr = []\n        cls.decoder_time_arr = []\n\n    @classmethod\n    def collect_sections(cls, logf):\n        ### Default is not using AG\n        ag = False\n        for i, line in enumerate(logf):\n            line = line.lstrip().rstrip().split(\" \")\n            if len(line)>1:\n                #print(f\"line: {line}\")\n                if line[1] == \"tokenizer:\":\n                    cls.start_idx_arr.append(i + 1)\n                elif line[1] == \"generate:\":\n                #     import pdb\n                #     pdb.set_trace()\n                    cls.end_idx_arr.append(i)\n                    t = float(line[2])\n                    cls.generate_times_arr.append(t)\n                    num_tokens_out = int(line[4])\n                    cls.num_tokens_out_arr.append(num_tokens_out)\n                    num_tokens_in 
    @classmethod\n    def parse_section(\n        cls,\n        outlog,\n        filename,\n        prompt_num,\n        logf,\n        start_idx,\n        end_idx,\n        generate_time,\n        num_tokens_out,\n        num_tokens_in\n        ):\n        cls.exec_time_arr = []\n        cls.prompt_time_arr = []\n        cls.token_time_arr = []\n        cls.decoder_time_arr = []\n        cls.decoder_time_arr_draft = []\n\n        for i in range(start_idx, end_idx, 1):\n            line = logf[i].strip().split(\" \")\n            if line[1] != \"model_decoder_forward\" and line[1] != \"draft_model_decoder_forward\":\n                # matmul profile line: m, k, n shapes plus start/end timestamps\n                m = int(line[1])\n                k = int(line[2])\n                n = int(line[4])\n                exec_time_start = float(line[11])\n                exec_time_end = float(line[12])\n                exec_time = (exec_time_end - exec_time_start)\n                cls.exec_time_arr.append(exec_time)\n                if m > 1:\n                    cls.prompt_time_arr.append(exec_time)\n                else:\n                    cls.token_time_arr.append(exec_time)\n            elif line[1] == \"model_decoder_forward\":\n                decoder_time = float(line[2])\n                cls.decoder_time_arr.append(decoder_time)\n            else:\n                decoder_time = float(line[2])\n                cls.decoder_time_arr_draft.append(decoder_time)\n\n        matmul_prompt_time = sum(cls.prompt_time_arr)\n        matmul_token_time = sum(cls.token_time_arr)\n        matmul_cumulative_time = sum(cls.exec_time_arr)\n        other_layers_time = (generate_time - matmul_cumulative_time)\n        new_tokens_generated = num_tokens_out - num_tokens_in\n\n        if len(cls.decoder_time_arr) > 0:\n            decoder_time_prefill_phase = cls.decoder_time_arr[0]\n            # decoder_time_token_phase = sum(cls.decoder_time_arr[1:])\n            decoder_time_token_phase = generate_time - decoder_time_prefill_phase\n            prefill_phase = decoder_time_prefill_phase * 1e3\n            if new_tokens_generated > 1:\n                time_per_token = (decoder_time_token_phase * 1e3) / (new_tokens_generated - 1)\n                # tokens_per_sec = 1000.0 / time_per_token\n                tokens_per_sec = new_tokens_generated / generate_time\n            else:\n                time_per_token = \"na\"\n                tokens_per_sec = \"na\"\n        else:\n            decoder_time_prefill_phase = \"na\"\n            decoder_time_token_phase = \"na\"\n            prefill_phase = \"na\"\n            time_per_token = \"na\"\n            tokens_per_sec = \"na\"\n\n        outlog.write(\n            f\"{filename},{prompt_num+1},{num_tokens_in},{matmul_prompt_time:.3f},{matmul_token_time:.3f},{matmul_cumulative_time:.3f},{other_layers_time:.3f},{generate_time:.3f},{decoder_time_prefill_phase},{decoder_time_token_phase},{num_tokens_out},{new_tokens_generated},{prefill_phase},{time_per_token},{tokens_per_sec}\\n\")\n\n        outlog.flush()\n\n        return [\n            prompt_num + 1,\n            num_tokens_in,\n            new_tokens_generated,\n            generate_time,\n            prefill_phase,\n            time_per_token,\n            tokens_per_sec,\n            # no AG in this log: placeholder values for accept num, guess num,\n            # and tokens generated at the first step\n            1,\n            1,\n            1\n        ]\n\n
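    # parse_section_ag handles assisted-generation logs, which additionally carry draft_model_decoder_forward latencies and valid_tokens acceptance counts.\n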
    @classmethod\n    def parse_section_ag(\n        cls,\n        outlog,\n        filename,\n        prompt_num,\n        logf,\n        start_idx,\n        end_idx,\n        generate_time,\n        num_tokens_out,\n        num_tokens_in\n        ):\n        cls.exec_time_arr = []\n        cls.prompt_time_arr = []\n        cls.token_time_arr = []\n        cls.decoder_time_arr = []\n        cls.decoder_time_arr_draft = []\n        cls.valid_tokens = []\n\n        for i in range(start_idx, end_idx, 1):\n            line = logf[i].strip().split(\" \")\n            if line[1] != \"model_decoder_forward\" and line[1] != \"draft_model_decoder_forward\" and line[1] != \"valid_tokens\":\n                # matmul profile line: m, k, n shapes plus start/end timestamps\n                m = int(line[1])\n                k = int(line[2])\n                n = int(line[4])\n                exec_time_start = float(line[11])\n                exec_time_end = float(line[12])\n                exec_time = (exec_time_end - exec_time_start)\n                cls.exec_time_arr.append(exec_time)\n                if m > 1:\n                    cls.prompt_time_arr.append(exec_time)\n                else:\n                    cls.token_time_arr.append(exec_time)\n            elif line[1] == \"valid_tokens\":\n                valid_tokens = float(line[2])\n                cls.valid_tokens.append(valid_tokens)\n            elif line[1] == \"model_decoder_forward\":\n                decoder_time = float(line[2])\n                cls.decoder_time_arr.append(decoder_time)\n            else:\n                decoder_time = float(line[2])\n                cls.decoder_time_arr_draft.append(decoder_time)\n\n        matmul_prompt_time = sum(cls.prompt_time_arr)\n        matmul_token_time = sum(cls.token_time_arr)\n        matmul_cumulative_time = sum(cls.exec_time_arr)\n        other_layers_time = (generate_time - matmul_cumulative_time)\n        new_tokens_generated = num_tokens_out - num_tokens_in\n\n        if len(cls.decoder_time_arr) > 0:\n            # prefill covers the first target forward plus the draft forwards of the\n            # first speculative step (the code assumes 5 assistant tokens per step)\n            # decoder_time_prefill_phase = cls.decoder_time_arr[0]\n            # decoder_time_token_phase = sum(cls.decoder_time_arr[1:])\n            decoder_time_prefill_phase = cls.decoder_time_arr[0] + sum(cls.decoder_time_arr_draft[:5])\n            # decoder_time_token_phase = sum(cls.decoder_time_arr[1:]) + sum(cls.decoder_time_arr_draft[5:])\n            decoder_time_token_phase = generate_time - decoder_time_prefill_phase\n\n            # every target forward yields exactly one token of its own, so the rest\n            # of the valid tokens were accepted draft guesses\n            accepted_num = sum(cls.valid_tokens) - len(cls.decoder_time_arr)\n            guessed_num = len(cls.decoder_time_arr_draft)\n\n            prefill_phase = decoder_time_prefill_phase * 1e3\n            if new_tokens_generated > 1:\n                # time_per_token = (decoder_time_token_phase * 1e3) / (new_tokens_generated - 1)\n                time_per_token = (decoder_time_token_phase * 1e3) / (new_tokens_generated - cls.valid_tokens[0])\n                tokens_per_sec = new_tokens_generated / generate_time\n            else:\n                time_per_token = \"na\"\n                tokens_per_sec = \"na\"\n        else:\n
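            # no model_decoder_forward entries were found in this section of the log, so the timing breakdown cannot be derived; emit \"na\" placeholders\n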
tokens_per_sec = \"na\"\n            accept_rate = \"na\"\n\n        outlog.write(\n            f\"{filename},{prompt_num+1},{num_tokens_in},{matmul_prompt_time:.3f},{matmul_token_time:.3f},{matmul_cumulative_time:.3f},{other_layers_time:.3f},{generate_time:.3f},{decoder_time_prefill_phase},{decoder_time_token_phase},{num_tokens_out},{new_tokens_generated},{prefill_phase},{time_per_token},{tokens_per_sec}\\n\")\n\n        outlog.flush()\n\n        return [\n            prompt_num + 1,\n            num_tokens_in,\n            new_tokens_generated,\n            generate_time ,\n            prefill_phase,\n            time_per_token,\n            tokens_per_sec,\n            accepted_num,\n            guessed_num,\n            cls.valid_tokens[0]\n        ]\n\n    @classmethod\n    def analyze_profiling(cls, in_file, out_file):\n        out_file.write(\n            \"Filename,Example#,Num_Tokens_In,MatMul_time_Prefill_phase[s],MatMul_time_Token_phase[s],MatMul_time_Cumulative[s],All_Other_layers[s],Generate_Time[s],Decoder_time_Prefill_phase[s],Decoder_time_Token_phase[s],Num_Tokens_Out,Num_New_Tokens,Prefill_Phase[ms],Time_per_Token[ms],Tokens\\sec\\n\"\n        )\n        with open(in_file, \"r\") as f:\n            logf = f.readlines()\n            ag = cls.collect_sections(logf)\n\n            perf_table = [\n                [\n                    \"Example#\",\n                    \"Prompt Length (tokens)\",\n                    \"New Tokens Generated\",\n                    \"Total Time (s)\",\n                    \"Prefill Phase (ms)\",\n                    \"Time/Token (ms)\",\n                    \"Tokens/Sec\",\n                    \"Accept num\",\n                    \"Guess num\",\n                    \"New Tokens Generated at first step\", \n                ]\n            ]\n            if ag:\n                for i in range(len(cls.start_idx_arr)):\n                    perf_table.append(\n                        cls.parse_section_ag(\n                            out_file,\n                            in_file,\n                            i,\n                            logf,\n                            cls.start_idx_arr[i],\n                            cls.end_idx_arr[i],\n                            cls.generate_times_arr[i],\n                            cls.num_tokens_out_arr[i],\n                            cls.num_tokens_in_arr[i],\n                        )\n                    )\n            else:\n                for i in range(len(cls.start_idx_arr)):\n                    perf_table.append(\n                        cls.parse_section(\n                            out_file,\n                            in_file,\n                            i,\n                            logf,\n                            cls.start_idx_arr[i],\n                            cls.end_idx_arr[i],\n                            cls.generate_times_arr[i],\n                            cls.num_tokens_out_arr[i],\n                            cls.num_tokens_in_arr[i],\n                        )\n                    )\n            print(tab.tabulate(perf_table, headers=\"firstrow\", tablefmt=\"github\"))\n        inference_Time_model_generate = 0\n        Prefill_time = 0\n        generated_token_number = 0\n        generated_token_number_at_first_step = 0\n        max_generated_token_number = -1\n        min_generated_token_number = 10000000\n        Decoding_Latency = 0\n        Troughput = 0\n        accept_num = 0\n        guess_num = 0\n        for pt in perf_table[1:]:\n            pt = 
        for pt in perf_table[1:]:\n            pt = [float(elem) for elem in pt]\n            inference_Time_model_generate += pt[3]\n            Prefill_time += pt[4] / 1e3\n            generated_token_number += pt[2]\n            max_generated_token_number = pt[2] if max_generated_token_number < pt[2] else max_generated_token_number\n            min_generated_token_number = pt[2] if min_generated_token_number > pt[2] else min_generated_token_number\n            generated_token_number_at_first_step += pt[-1]\n            accept_num += pt[-3]\n            guess_num += pt[-2]\n        Decoding_Latency = (inference_Time_model_generate - Prefill_time) / (generated_token_number - generated_token_number_at_first_step)\n        Throughput = (generated_token_number / inference_Time_model_generate)\n        print(\"inference_Time_model_generate: \", inference_Time_model_generate)\n        print(\"Prefill_time: \", Prefill_time)\n        print(\"generated_token_number: \", generated_token_number)\n        print(\"max_generated_token_number: \", max_generated_token_number)\n        print(\"min_generated_token_number: \", min_generated_token_number)\n        print(\"Decoding_Latency: \", Decoding_Latency)\n        print(\"Throughput: \", Throughput)\n        print(\"acceptance rate: \", accept_num/guess_num)\n\n#-------------------------begin evaluation------------------------------\ndef benchmark_code(model, log_file, max_seqlen=512, assistant_model=None, sample=False, temperature=None, stopping=None):\n    \"\"\"Run the HumanEval generation benchmark and parse the resulting profiling log.\"\"\"\n    def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):\n        \"\"\"\n        Writes an iterable of dictionaries to jsonl\n        \"\"\"\n        if append:\n            mode = 'ab'\n        else:\n            mode = 'wb'\n        filename = os.path.expanduser(filename)\n        if filename.endswith(\".gz\"):\n            with open(filename, mode) as fp:\n                with gzip.GzipFile(fileobj=fp, mode='wb') as gzfp:\n                    for x in data:\n                        gzfp.write((json.dumps(x) + \"\\n\").encode('utf-8'))\n        else:\n            with open(filename, mode) as fp:\n                for x in data:\n                    fp.write((json.dumps(x) + \"\\n\").encode('utf-8'))\n\n
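    # clip_input prepends a fixed one-shot system prompt; if the prompt plus max_new_tokens would overflow the context window, tokens are dropped from the tail of the prompt body (the first two and last three tokens are kept)\n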
    def clip_input(tokenizer, prompt, max_new_tokens=512, max_seql=4096):\n        system_prompt = \"# python code to complete some task. # Create a function to calculate the sum of a sequence of integers. [PYTHON]\\ndef sum_sequence(sequence):\\n  sum = 0\\n  for num in sequence:\\n    sum += num\\n  return sum \\n[/PYTHON]\\n#\"\n        prompt = prompt['prompt']\n        prompt = system_prompt + prompt\n        input_ids = tokenizer(prompt, return_tensors='pt').input_ids\n        if len(input_ids[0]) + max_new_tokens >= max_seql:\n            print('(input ids + max new tokens) >= {}; clipping the prompt'.format(max_seql))\n            sample_num = (len(input_ids[0]) + max_new_tokens - max_seql)\n            input_ids = torch.cat((input_ids[0][:2], input_ids[0][2:-3][:-sample_num], input_ids[0][-3:]), dim=0).unsqueeze(0)\n        return input_ids\n\n    # humaneval data\n    def get_humaneval(path=None):\n        if path is not None:\n            with open(path, 'r') as json_file:\n                humaneval_data = json.load(json_file)\n            return humaneval_data\n        else:\n            from datasets import load_dataset\n            prompt_data = load_dataset(\"openai_humaneval\")\n            return prompt_data\n\n    def count_indent(text: str) -> int:\n        count = 0\n        for char in text:\n            if char == \" \":\n                count += 1\n            else:\n                break\n        return count\n\n    def fix_indents(text: str, multiple: int = 2):\n        outputs = []\n        for line in text.split(\"\\n\"):\n            while count_indent(line) % multiple != 0:\n                line = \" \" + line\n            outputs.append(line)\n        return \"\\n\".join(outputs)\n\n    def filter_code(completion: str, model=None) -> str:\n        # keep only the first code block of the completion\n        completion = completion.lstrip(\"\\n\")\n        return completion.split(\"\\n\\n\")[0]\n\n    testloader = get_humaneval()[\"test\"]\n\n    task_name = \"humaneval\"\n    big_reorg_dict = defaultdict(list)\n\n
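    # per-prompt loop: the [PROFILE] tokenizer/generate markers logged here are exactly what ProfileLLM.collect_sections later scans for\n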
    for i, prompt in enumerate(testloader):\n        task_id = prompt[\"task_id\"]\n        input_ids = clip_input(model.tokenizer, prompt).to(model.device)\n        logging.critical(\"[PROFILE] tokenizer:\")\n\n        start = time.perf_counter()\n        generate_ids = model.generate(input_ids, do_sample=sample, max_new_tokens=max_seqlen, pad_token_id=model.tokenizer.eos_token_id,\n                temperature=temperature,\n                stopping_criteria=stopping,\n                top_k=10, top_p=0.95,\n                assistant_model=assistant_model)\n        end = time.perf_counter()\n        generate_time = (end - start)\n        prompt_tokens = input_ids.shape[1]\n        num_tokens_out = generate_ids.shape[1]\n        new_tokens_generated = num_tokens_out - prompt_tokens\n        time_per_token = (generate_time / new_tokens_generated) * 1e3\n        logging.critical(f\"[PROFILE] generate: {generate_time} for {num_tokens_out} tokens; prompt-tokens: {prompt_tokens}; time per generated token: {time_per_token}\")\n        completion = model.tokenizer.decode(generate_ids[0, input_ids.shape[1]:])\n        completion = filter_code(fix_indents(completion))\n        val_completion = completion\n        print(f\"response: {val_completion}\")\n        logging.critical(f\"response: {val_completion}\")\n\n        #-------------------------below is results store-------------------------------------\n        comp_item = dict(task_id=\"{}\".format(task_id), completion=val_completion)\n        big_reorg_dict[\"yxl\"].append(comp_item)\n\n    result_data_path = './results/codellama_spec_{}_on_mi250'.format(task_name)\n    if not os.path.exists(result_data_path):\n        os.makedirs(result_data_path)\n\n    for key, completion in big_reorg_dict.items():\n        write_jsonl(\"{0}/{1}.jsonl\".format(result_data_path, key), completion)\n\n    # flush the log, then parse it into a CSV profiling summary\n    logging.shutdown()\n    out_file = log_file.replace(\".log\", \"_profile.csv\")\n    out_file = open(out_file, \"w\")\n    ProfileLLM.analyze_profiling(log_file, out_file)\n    out_file.close()\n\n\ndef warmup(model, prompts, max_new_tokens=30):\n    print(\"Warming up ...\")\n    for prompt in prompts[0:1]:\n        inputs = model.tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n        generate_ids = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=max_new_tokens)\n        _ = model.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n    print(\"Warm up DONE!!\")\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--target_model\", type=str, default=\"target_model\", help=\"Target model path\")\n    parser.add_argument(\"--draft_model\", type=str, default=None, help=\"Draft model path\")\n    parser.add_argument(\"--max_new_tokens\", type=int, default=512, help=\"Maximum new tokens\")\n    parser.add_argument(\"--do_sample\", action='store_true', help=\"Whether to use sampling\")\n    parser.add_argument(\"--temperature\", type=float, default=0.0, help=\"Temperature for sampling\")\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"Seed for reproduction\")\n    parser.add_argument(\"--device\", type=str, default=\"cuda\", help=\"Device for models\")\n    parser.add_argument(\"--bf16\", action='store_true', help=\"Use bfloat16 precision for models\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n    np.random.seed(args.seed)\n\n    codellama_checkpoint = args.target_model\n    assistant_checkpoint = args.draft_model\n\n    device = torch.device(args.device)\n    torch_dtype = torch.bfloat16 if args.bf16 else torch.float32\n\n    tokenizer = AutoTokenizer.from_pretrained(codellama_checkpoint)\n    model = LlamaModelEval.from_pretrained(codellama_checkpoint, torch_dtype=torch_dtype).to(device)\n    model.tokenizer = tokenizer\n\n    sample = args.do_sample\n    use_spec = args.draft_model is not None\n    temperature = args.temperature\n    log_dir = \"./logs\"\n    if not os.path.exists(log_dir):\n        os.makedirs(log_dir)\n\n    if use_spec:\n        log_file = log_dir + \"/log_codellama_spec.log\"\n    else:\n        log_file = log_dir + \"/log_codellama_target.log\"\n\n    logging.basicConfig(filename=log_file,\n                        filemode='w',\n                        level=logging.CRITICAL)\n\n    warmup_prompts = [\"from typing import List\\n\\n\\ndef parse_nested_parens(paren_string: str) -> List[int]:\\n    \\\"\\\"\\\" Input to this function is a string represented multiple groups for nested parentheses separated by spaces.\\n    For each of the group, output the deepest level of nesting of parentheses.\\n    E.g. 
(()()) has maximum two levels of nesting while ((())) has three.\\n\\n    >>> parse_nested_parens('(()()) ((())) () ((())()())')\\n    [2, 3, 1, 3]\\n    \\\"\\\"\\\"\\n\"]\n    warmup(model, warmup_prompts)\n    if use_spec:\n        assistant_model = LlamaModelEval_Draft.from_pretrained(assistant_checkpoint, torch_dtype=torch_dtype).to(device)\n        benchmark_code(model, log_file, assistant_model=assistant_model, sample=sample, temperature=temperature)\n    else:\n        benchmark_code(model, log_file, sample=sample, temperature=temperature)\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "speculative_decoding/codellama_spec.sh",
    "content": "# Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\ntarget_model=/path/to/target/model\ndraft_model=/path/to/draft/model\nmax_new_tokens=512\n\npython ./speculative_decoding/codellama_spec.py \\\n    --target_model $target_model\\\n    --draft_model $draft_model\\\n    --max_new_tokens $max_new_tokens \\\n    --temperature 0.1 \\\n    --do_sample \\\n    --bf16 \\\n"
  },
  {
    "path": "speculative_decoding/utils.patch",
    "content": "@@ -92,6 +92,7 @@\n if is_accelerate_available():\n     from accelerate.hooks import AlignDevicesHook, add_hook_to_module\n \n+import logging\n \n @dataclass\n class GenerateDecoderOnlyOutput(ModelOutput):\n@@ -4415,6 +4416,7 @@\n             # is no match.\n \n             # 4.1. Get the valid continuation, after the matching tokens\n+            logging.critical(f\"[PROFILE] valid_tokens {valid_tokens.shape[1]} , n_matches {n_matches}\")\n             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)\n             if streamer is not None:\n                 streamer.put(valid_tokens.cpu())\n"
  }
]