Full Code of jzhang38/TinyLlama for AI

Repository: jzhang38/TinyLlama
Branch: main
Commit: bf122247c486
Files: 37
Total size: 317.8 KB

Directory structure:
gitextract_67kobddf/

├── .gitignore
├── EVAL.md
├── LICENSE
├── PRETRAIN.md
├── README.md
├── README_zh-CN.md
├── chat_gradio/
│   ├── README.md
│   ├── app.py
│   └── requirements.txt
├── lit_gpt/
│   ├── __init__.py
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── config.py
│   ├── fused_cross_entropy.py
│   ├── fused_rotary_embedding.py
│   ├── lora.py
│   ├── model.py
│   ├── packed_dataset.py
│   ├── rmsnorm.py
│   ├── speed_monitor.py
│   ├── tokenizer.py
│   └── utils.py
├── pretrain/
│   ├── tinyllama.py
│   └── tinyllama_code.py
├── requirements.txt
├── script.sh
├── scripts/
│   ├── convert_hf_checkpoint.py
│   ├── convert_lit_checkpoint.py
│   ├── prepare_redpajama.py
│   ├── prepare_slimpajama.py
│   └── prepare_starcoder.py
├── sft/
│   ├── finetune.py
│   ├── script.sh
│   ├── simple_inference.py
│   └── simple_inference2.py
└── speculative_decoding/
    ├── README.md
    └── instruct_hf_assisted_decoding.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
.idea
.DS_Store
*.egg-info
build
.venv
.vscode

# data
data
checkpoints
out
wandb

tests/original_falcon_40b.py
sft/output
sft/wandb

================================================
FILE: EVAL.md
================================================
## Evaluate TinyLlama

### GPT4All Benchmarks

We evaluate TinyLlama's commonsense reasoning ability following the [GPT4All](https://gpt4all.io/index.html) evaluation suite. We include Pythia as our baseline. We report the acc_norm by default. 

Base models:

| Model                                     | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg |
|-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----|
| Pythia-1.0B                               |        300B     | 47.16     | 31.40| 53.43      | 27.05 | 48.99 | 60.83 | 69.21 | 48.30 |
| TinyLlama-1.1B-intermediate-step-50K-104b |        103B     | 43.50     | 29.80| 53.28      | 24.32 | 44.91 | 59.66 | 67.30 | 46.11|
| TinyLlama-1.1B-intermediate-step-240k-503b|        503B     | 49.56     |31.40 |55.80       |26.54  |48.32  |56.91  |69.42  | 48.28 |
| TinyLlama-1.1B-intermediate-step-480k-1007B |     1007B     | 52.54     | 33.40 | 55.96      | 27.82 | 52.36 | 59.54 | 69.91 | 50.22 |
| TinyLlama-1.1B-intermediate-step-715k-1.5T |     1.5T     | 53.68     | 35.20 | 58.33      | 29.18 | 51.89 | 59.08 | 71.65 | 51.29 |
| TinyLlama-1.1B-intermediate-step-955k-2T |     2T     | 54.63     | 33.40 | 56.83      | 28.07 | 54.67 | 63.21 | 70.67 | 51.64 |
| TinyLlama-1.1B-intermediate-step-1195k-2.5T  |     2.5T     | 58.96     | 34.40 | 58.72      | 31.91 | 56.78 | 63.21 | 73.07 | 53.86|
| TinyLlama-1.1B-intermediate-step-1431k-3T  |     3T     | 59.20     | 36.00 | 59.12      | 30.12 | 55.25 | 57.83 | 73.29 | 52.99|
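
The `avg` column appears to be the unweighted mean of the seven task scores, rounded to two decimals. The sketch below checks that assumption on two rows copied from the table above; the helper name `commonsense_avg` is ours, and we have not assumed the relation holds exactly for every row.

```python
from statistics import mean

# Scores copied from the table above, in column order:
# HellaSwag, Obqa, WinoGrande, ARC_c, ARC_e, boolq, piqa
ROWS = {
    "Pythia-1.0B": [47.16, 31.40, 53.43, 27.05, 48.99, 60.83, 69.21],
    "TinyLlama-1.1B-intermediate-step-480k-1007B": [52.54, 33.40, 55.96, 27.82, 52.36, 59.54, 69.91],
}


def commonsense_avg(scores):
    """Unweighted mean of the per-task scores, rounded to two decimals."""
    return round(mean(scores), 2)


for name, scores in ROWS.items():
    print(f"{name}: {commonsense_avg(scores):.2f}")
```

Both rows come out equal to the published `avg` values (48.30 and 50.22).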


Chat models:

| Model                                     | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg |
|-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----|
| [TinyLlama-1.1B-Chat-v0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1)                 |   503B     | 53.81     |32.20 | 55.01  | 28.67 |49.62  | 58.04 | 69.64 | 49.57 |
| [TinyLlama-1.1B-Chat-v0.2](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.2)                 |   503B     | 53.63     |32.80 | 54.85  | 28.75 |49.16  | 55.72 | 69.48 | 49.20 |
| [TinyLlama-1.1B-Chat-v0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3)                 |   1T       | 56.81     |34.20 | 55.80  | 30.03 |53.20  | 59.57 | 69.91 | 51.36 |
| [TinyLlama-1.1B-Chat-v0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4)             |   1.5T     | 58.59     |35.40 | 58.80  | 30.80 |54.04  | 57.31 | 71.16 | 52.30 |


We observed large improvements once we finetuned the model. We attribute this to two factors: (1) the base model has not yet undergone learning-rate cool-down, and finetuning helps to cool the learning rate down; (2) the SFT stage better elicits the model's internal knowledge.

You can obtain the above scores by running [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness):
```bash
python main.py \
    --model hf-causal \
    --model_args pretrained=PY007/TinyLlama-1.1B-Chat-v0.1,dtype="float" \
    --tasks hellaswag,openbookqa,winogrande,arc_easy,arc_challenge,boolq,piqa \
    --device cuda:0 --batch_size 32
```



### Instruct-Eval Benchmarks
We evaluate TinyLlama's ability in problem-solving on the [Instruct-Eval](https://github.com/declare-lab/instruct-eval) evaluation suite. 


| Model                                             | MMLU  | BBH   | HumanEval | DROP  |
| ------------------------------------------------- | ----- | ----- | --------- | ----- |
| Pythia-1.0B                                       | 25.70 | 28.19 | 1.83      | 4.25  |
| TinyLlama-1.1B-intermediate-step-50K-104b         | 26.45 | 28.82 | 5.49      | 11.42 |
| TinyLlama-1.1B-intermediate-step-240k-503b        | 26.16 | 28.83 | 4.88      | 12.43 |
| TinyLlama-1.1B-intermediate-step-480K-1T          | 24.65 | 29.21 | 6.1       | 13.03 |
| TinyLlama-1.1B-intermediate-step-715k-1.5T        | 24.85 | 28.2  | 7.93      | 14.43 |
| TinyLlama-1.1B-intermediate-step-955k-2T          | 25.97 | 29.07 | 6.71      | 13.14 |
| TinyLlama-1.1B-intermediate-step-1195k-token-2.5T | 25.92 | 29.32 | 9.15      | 15.45 |

You can obtain the above scores by running [instruct-eval](https://github.com/declare-lab/instruct-eval):
```bash
CUDA_VISIBLE_DEVICES=0 python main.py mmlu --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=1 python main.py bbh --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=2 python main.py drop --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=3 python main.py humaneval --model_name llama --n_sample 1 --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
```


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2023] Lightning AI

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: PRETRAIN.md
================================================
## Pretrain TinyLlama

### Installation
We expect you have CUDA 11.8 installed.
#### Install PyTorch Nightly
```bash
pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev'
```
#### Build XFormers from Source
Note: as of 2023/09/02, xformers does not provide pre-built binaries for torch 2.1. You have to build it from source.
```bash
pip uninstall ninja -y && pip install ninja -U
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```


#### Install Flash-Attention 2 and other fused operators:
```bash
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python setup.py install
cd csrc/rotary && pip install .
cd ../layer_norm && pip install .
cd ../xentropy && pip install .
cd ../../.. && rm -rf flash-attention
```
#### Install Remaining Dependencies
```
pip install -r requirements.txt tokenizers sentencepiece
```
Building xformers and flash-attention may take five minutes or more. Do not worry if the process seems to stall or the terminal prints many warnings.

Then you are ready to go 🎉!
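
Before launching a long run, it can help to confirm that the packages built above are visible to Python. The sketch below only looks the modules up without importing them, so it does not need a GPU. The import names are assumptions (they are what we expect the installs above to expose; this document does not state them), and `missing_modules` is a hypothetical helper.

```python
import importlib.util

# Import names are assumptions: the installs above are expected to expose
# these top-level modules, but the names are not stated in this document.
EXPECTED_MODULES = [
    "torch",
    "xformers",
    "flash_attn",          # flash-attention main package
    "rotary_emb",          # built from csrc/rotary
    "dropout_layer_norm",  # built from csrc/layer_norm
    "xentropy_cuda_lib",   # built from csrc/xentropy
]


def missing_modules(names):
    """Return the names that cannot be located on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(EXPECTED_MODULES)
    if missing:
        print("Not found:", ", ".join(missing))
    else:
        print("All expected modules were found.")
```

A module being found does not prove that its CUDA kernels load correctly; it only catches installs that silently failed.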

### Data Preparation

#### Download Datasets
Download the Slimpajama and Starcoderdata datasets to your chosen directory.
```bash
cd /path/to/dataset
git lfs install
git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B
git clone https://huggingface.co/datasets/bigcode/starcoderdata
```
The SlimPajama dataset takes up 893 GB of disk space and Starcoderdata takes 290 GB.

#### Tokenize data
Use the provided scripts to tokenize the datasets and divide them into chunks.
```bash
python scripts/prepare_starcoder.py --source_path /path/to/starcoderdata/ --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama  --destination_path data/slim_star_combined --split validation --percentage 1.0
python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama  --destination_path data/slim_star_combined --split train --percentage 1.0
```
The processed data takes up 1.8 TB of storage.
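
Since the raw and tokenized data together need roughly 3 TB, a quick free-space check before downloading may save time. This is a minimal sketch using only the sizes quoted in this document; `free_gb` and `has_room` are hypothetical helpers, and the sizes are treated as decimal gigabytes.

```python
import shutil

# Sizes quoted in this document, in GB.
RAW_GB = 893 + 290      # SlimPajama + Starcoderdata downloads
PROCESSED_GB = 1800     # tokenized output


def free_gb(path="."):
    """Free space on the filesystem containing `path`, in GB (10**9 bytes)."""
    return shutil.disk_usage(path).free / 1e9


def has_room(path, required_gb):
    """True if the filesystem containing `path` has at least `required_gb` free."""
    return free_gb(path) >= required_gb


if __name__ == "__main__":
    print(f"raw datasets: ~{RAW_GB} GB, tokenized data: ~{PROCESSED_GB} GB")
    print("room for both here:", has_room(".", RAW_GB + PROCESSED_GB))
```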

### Pretraining
If your setup comprises two nodes, each with 8 GPUs, you can initiate pretraining with the following commands:

On node 1:
```
lightning run model \
    --node-rank=0  \
    --main-address=172.16.101.5 \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star  --val_data_dir data/slim_star
```
On node 2:
```
lightning run model \
    --node-rank=1  \
    --main-address=172.16.101.5 \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star   --val_data_dir data/slim_star
```
You can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) if you have a Slurm cluster.
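
The two commands above differ only in `--node-rank`. As an illustration, the sketch below generates the command for each node from one template. It uses only the flags that appear in the commands above; the function `launch_command` is hypothetical and not part of this repository.

```python
import shlex


def launch_command(node_rank, main_address, num_nodes=2, devices=8,
                   data_dir="data/slim_star"):
    """Build the `lightning run model` argument list for one node."""
    return [
        "lightning", "run", "model",
        f"--node-rank={node_rank}",
        f"--main-address={main_address}",
        "--accelerator=cuda",
        f"--devices={devices}",
        f"--num-nodes={num_nodes}",
        "pretrain/tinyllama.py",
        "--devices", str(devices),
        "--train_data_dir", data_dir,
        "--val_data_dir", data_dir,
    ]


# Print one command per node; run each on the matching machine.
for rank in range(2):
    print(shlex.join(launch_command(rank, "172.16.101.5")))
```

Keeping the command in one place makes it harder for the nodes to drift apart, for example in `--main-address` or the data directories.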



================================================
FILE: README.md
================================================
<div align="center">

# TinyLlama-1.1B
English | [中文](README_zh-CN.md)

[Chat Demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat) | [Discord](https://discord.gg/74Wcx4j5Nb)
</div>

The TinyLlama project aims to **pretrain** a **1.1B Llama model on 3 trillion tokens**. With proper optimization, we can achieve this in "just" 90 days using 16 A100-40G GPUs 🚀🚀. Training started on 2023-09-01.

<div align="center">
  <img src=".github/TinyLlama_logo.png" width="300"/>
</div>

We adopted exactly the same architecture and tokenizer as Llama 2, so TinyLlama can be plugged into many open-source projects built upon Llama. TinyLlama is also compact, with only 1.1B parameters, which suits applications that have a restricted computation and memory footprint.

#### News
- 2023-12-18: Add two notes [1](https://whimsical-aphid-86d.notion.site/Release-of-TinyLlama-1-5T-Checkpoints-Postponed-01b266998c1c47f78f5ae1520196d194?pvs=4), [2](https://whimsical-aphid-86d.notion.site/Latest-Updates-from-TinyLlama-Team-7d30c01fff794da28ccc952f327c8d4f?pvs=4) explaining the changes of training curves, project schedules, and bug fixes.
- 2023-10-03: Add examples in speculative decoding with llama.cpp. Do check out the [speculative_decoding/README.md](speculative_decoding/README.md).
- 2023-10-02: 1. 1T-token checkpoint just dropped. 2. We document **all** intermediate checkpoints [here](https://huggingface.co/TinyLlama/tinyLlama-intermediate-checkpoints/tree/step-480k-token-1007B).
- 2023-09-28: Add a discord server.
- 2023-09-18: We added a [chat demo](https://huggingface.co/spaces/PY007/TinyLlama-Chat) so that you can play with TinyLlama-Chat-V0.1 right away.
- 2023-09-16: 1. We released the intermediate checkpoint trained on 503B tokens. 2. We released a chat model finetuned on OpenAssistant and added simple [finetuning](sft) scripts. 3. More eval benchmarks are added and documented in [EVAL.md](EVAL.md).

#### Evaluation
You can find the evaluation results of TinyLlama in [EVAL.md](EVAL.md).

#### Releases Schedule
We will be rolling out intermediate checkpoints on the schedule below.

Base models:

| Date       | HF Checkpoint                                   | Tokens | Step | Commonsense Avg |
|------------|-------------------------------------------------|--------|------| --------------- |
| 2023-09-01 | Pythia-1.0B                                     | 300B   | 143k   | 48.30 |
| 2023-09-04 | [TinyLlama-1.1B-intermediate-step-50k-105b](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) | 105B   | 50k   | 46.11|
| 2023-09-16 | [TinyLlama-1.1B-intermediate-step-240k-503b](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b)                                            | 503B   | 240K    | 48.28 |
| 2023-10-01 | [TinyLlama-1.1B-intermediate-step-480k-1T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-480k-1T) | 1T     | 480k | 50.22 |
| 2023-11-04 | [TinyLlama-1.1B-intermediate-step-715k-1.5T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T)                                            | 1.5T     |715k    |51.29 |
| 2023-11-20 | [TinyLlama-1.1B-intermediate-step-955k-2T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T)                                            | 2T     |955k    |51.64 |
| 2023-12-11 | [TinyLlama-1.1B-intermediate-step-1195k-2.5T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T)              | 2.5T     | 1195k    |53.86 |
| 2023-12-28 | [TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T)              | 3T   | 1431k  | 52.99 |

We are preparing a note offering a possible explanation for the significant improvement from the 2T to the 2.5T checkpoint (it is related to the [bos_id issue](https://github.com/jzhang38/TinyLlama/issues/83)).

Chat models:

| Date       | HF Checkpoint                                   | Tokens | Step | Commonsense Avg |
|------------|-------------------------------------------------|--------|------| --------------- |
| 2023-09-16 | [TinyLlama-1.1B-Chat-V0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1)                                            | 503B   | 240K    |  49.57 |
| 2023-10-01 | [TinyLlama-1.1B-Chat-V0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3)                                            | 1T   | 480K    |  51.36 |
| 2023-11-04 | [TinyLlama-1.1B-Chat-V0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4)                                            | 1.5T   | 715K    |  52.30 |

Note that the learning rate of the base model has not cooled down yet, so we recommend also using the finetuned chat model.

Meanwhile, you can track the live cross entropy loss [here](https://wandb.ai/lance777/lightning_logs/reports/metric-train_loss-23-09-04-23-38-15---Vmlldzo1MzA4MzIw?accessToken=5eu2sndit2mo6eqls8h38sklcgfwt660ek1f2czlgtqjv2c6tida47qm1oty8ik9).

## Potential Use Cases
Tiny but strong language models are useful for many applications. Here are some potential use cases:
- Assisting speculative decoding of larger models. (See this [tutorial](https://twitter.com/karpathy/status/1697318534555336961) by Andrej Karpathy)
- Deployment on edge devices with restricted memory and computational capacities, for functionalities like real-time machine translation without an internet connection (the 4bit-quantized TinyLlama-1.1B's weight only takes up 637 MB).
- Enabling real-time dialogue generation in video games.
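
The 637 MB figure in the list above can be put in context with a back-of-envelope estimate of weight storage. The sketch below is not how the figure was measured. The 4.5 bits/weight case assumes a block format that stores one 16-bit scale per 32 four-bit weights; that is our assumption about the quantizer, not something this README states.

```python
def weight_megabytes(n_params, bits_per_weight):
    """Approximate weight storage in MB (10**6 bytes)."""
    return n_params * bits_per_weight / 8 / 1e6


N_PARAMS = 1.1e9  # parameter count from this README

# Pure 4-bit storage, ignoring any per-block metadata.
print(f"4.0 bits/weight: {weight_megabytes(N_PARAMS, 4.0):.0f} MB")
# Assumed block format: 32 four-bit weights plus one 16-bit scale.
print(f"4.5 bits/weight: {weight_megabytes(N_PARAMS, 4.5):.0f} MB")
# 16-bit baseline for comparison.
print(f"16 bits/weight:  {weight_megabytes(N_PARAMS, 16):.0f} MB")
```

Under these assumptions the estimates land at roughly 550 MB and 620 MB, against 2200 MB at 16 bits. The remaining gap to 637 MB is presumably file metadata and tensors kept at higher precision, which the README does not break down.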

Moreover, our code can be a **reference for enthusiasts keen on pretraining language models under 5 billion parameters** without diving too early into [Megatron-LM](https://github.com/NVIDIA/Megatron-LM).

## Training Details
Below are some details of our training setup:

| Setting                         | Description                                                    |
|---------------------------------|----------------------------------------------------------------|
| Parameters                      | 1.1B                                                           |
| Attention Variant               | Grouped Query Attention                                        |
| Model Size                      | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632|
| Sequence Length                 | 2048                                                           |
| Batch Size                      | 2 million tokens (2048 * 1024)                                             |
| Learning Rate                   | 4e-4                                                           |
| Learning Rate Schedule          | Cosine with 2000 warmup steps. See [Issue 27](https://github.com/jzhang38/TinyLlama/issues/27) for a minor bug     |
| Training Data                   | [Slimpajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) & [Starcoderdata](https://huggingface.co/datasets/bigcode/starcoderdata) |
| Data Preprocessing              | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata |
| Combined Dataset Size           | Around 950B tokens                                              |
| Total Tokens During Training    | 3 trillion (slightly more than 3 epochs/1430k steps)                                          |
| Natural Language to Code Ratio  | 7:3                                                            |
| Hardware                        | 16 A100-40G GPUs                                               |
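
The numbers in the table above can be cross-checked with a little arithmetic. The sketch below counts parameters from the listed shapes and derives the step and epoch counts. It assumes Llama 2's 32,000-token vocabulary, untied input and output embeddings, RMSNorm weights, and no bias terms; none of these are stated in the table.

```python
# Values from the table above.
N_LAYER, N_HEAD, N_QUERY_GROUPS = 22, 32, 4
N_EMBD, INTERMEDIATE = 2048, 5632
SEQ_LEN, SEQS_PER_BATCH = 2048, 1024
DATASET_TOKENS, TOTAL_TOKENS = 950e9, 3e12

# Assumptions not stated in the table: Llama 2's 32,000-token vocabulary,
# untied input/output embeddings, RMSNorm weights only, no bias terms.
VOCAB = 32_000


def count_parameters():
    head_dim = N_EMBD // N_HEAD
    kv_dim = N_QUERY_GROUPS * head_dim                 # key/value width under GQA
    attn = 2 * N_EMBD * N_EMBD + 2 * N_EMBD * kv_dim   # q and o, then k and v
    mlp = 3 * N_EMBD * INTERMEDIATE                    # SwiGLU: gate, up, down
    norms = 2 * N_EMBD                                 # two norm weights per block
    per_layer = attn + mlp + norms
    # Blocks + input embedding + output head + final norm.
    return N_LAYER * per_layer + 2 * VOCAB * N_EMBD + N_EMBD


tokens_per_step = SEQ_LEN * SEQS_PER_BATCH
print(f"parameters:      {count_parameters():,}")
print(f"tokens per step: {tokens_per_step:,}")
print(f"steps for 3T:    {TOTAL_TOKENS / tokens_per_step:,.0f}")
print(f"epochs:          {TOTAL_TOKENS / DATASET_TOKENS:.2f}")
```

Under these assumptions the count is about 1.1 billion parameters, and 3 trillion tokens at 2,097,152 tokens per step gives roughly 1430k steps and a little over three passes through the 950B-token dataset, matching the table.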






## Blazingly Fast
Our codebase supports the following features:
- multi-gpu and multi-node distributed training with FSDP.
- flash attention 2.
- fused layernorm.
- fused swiglu.
- fused cross entropy loss.
- fused rotary positional embedding.

Credit: flash attention 2, fused layernorm, fused cross entropy loss, and fused
rotary positional embedding are from the [FlashAttention repo](https://github.com/Dao-AILab/flash-attention/). Fused swiglu is from [xformers](https://github.com/facebookresearch/xformers).

Thanks to those optimizations, we achieve a throughput of **24k** tokens per second per A100-40G GPU, which translates to **56% model flops utilization** without activation checkpointing (We expect the MFU to be even higher on A100-80G). It means you can train a chinchilla-optimal TinyLlama (1.1B param, 22B tokens) in **32 hours with 8 A100**. Those optimizations also greatly reduce the memory footprint, allowing us to stuff our 1.1B model into 40GB GPU RAM and train with a per-gpu batch size of 16k tokens. **You can also pretrain TinyLlama on 3090/4090 GPUs with a smaller per-gpu batch size**.
Below is a comparison of the training speed of our codebase with that of Pythia and MPT.


| Model                             | A100 GPU hours taken on 300B tokens| 
|-----------------------------------|------------------------------------|
|TinyLlama-1.1B                     | 3456                               |    
|[Pythia-1.0B](https://huggingface.co/EleutherAI/pythia-1b)                        | 4830                               |
|[MPT-1.3B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b)                           | 7920                               |  

<small> The Pythia number comes from their [paper](https://arxiv.org/abs/2304.01373). The MPT number comes from [here](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b), in which they say MPT-1.3B "was trained on 440 A100-40GBs for about half a day" on 200B tokens. </small>
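
The training-time figures above follow from the quoted throughput. The sketch below reproduces them, assuming throughput stays constant and scales linearly with GPU count, and reading "about half a day" as 12 hours for the MPT estimate. `gpu_hours` is a hypothetical helper.

```python
def gpu_hours(tokens, tokens_per_sec_per_gpu):
    """GPU-hours needed to process `tokens` at a given per-GPU throughput."""
    return tokens / tokens_per_sec_per_gpu / 3600


THROUGHPUT = 24_000  # tokens/s per A100-40G, as quoted above

# Chinchilla-optimal run: 22B tokens spread over 8 GPUs.
wall_clock = gpu_hours(22e9, THROUGHPUT) / 8
print(f"22B tokens on 8 GPUs: {wall_clock:.1f} hours")

# 300B tokens at the quoted (rounded) throughput.
print(f"300B tokens: {gpu_hours(300e9, THROUGHPUT):.0f} GPU-hours")

# MPT: 440 GPUs for about half a day (assumed 12 hours) on 200B tokens,
# scaled linearly to 300B tokens.
mpt = 440 * 12 * (300 / 200)
print(f"MPT scaled to 300B tokens: {mpt:.0f} GPU-hours")
```

The first figure comes out just under 32 hours. The 300B estimate is about 3470 GPU-hours, slightly above the 3456 in the table, presumably because 24k is a rounded throughput. The MPT estimate matches the table's 7920.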

The fact that TinyLlama is a relatively small model with grouped query attention means it is also fast during inference. Below are some throughputs that we measured:

| Framework | Device | Settings | Throughput (tokens/sec) |
|-----------|--------------|-----|-----------|
|[Llama.cpp](https://github.com/ggerganov/llama.cpp) | Mac M2 16GB RAM         |  batch_size=1; 4-bit inference|    71.8     | 
|[vLLM](https://github.com/vllm-project/vllm)       | A40 GPU  | batch_size=100, n=10 |   7094.5         |
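
One reason grouped query attention helps at inference time is that it shrinks the key/value cache. The sketch below estimates the cache size for one 2048-token sequence using the model shapes from the Training Details table. It assumes 16-bit cache entries and compares against a hypothetical variant with one key/value head per query head; the measured throughputs above are not derived from it.

```python
def kv_cache_bytes(n_layer, n_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor 2 covers the separate key and value tensors in every layer.
    return 2 * n_layer * n_kv_heads * head_dim * seq_len * bytes_per_value


N_LAYER, N_HEAD, N_QUERY_GROUPS, N_EMBD, SEQ_LEN = 22, 32, 4, 2048, 2048
HEAD_DIM = N_EMBD // N_HEAD

gqa = kv_cache_bytes(N_LAYER, N_QUERY_GROUPS, HEAD_DIM, SEQ_LEN)
mha = kv_cache_bytes(N_LAYER, N_HEAD, HEAD_DIM, SEQ_LEN)
print(f"GQA (4 groups): {gqa / 2**20:.0f} MiB")
print(f"MHA (32 heads): {mha / 2**20:.0f} MiB")
print(f"reduction:      {mha // gqa}x")
```

With 4 query groups instead of 32 key/value heads, the cache is 8 times smaller, which leaves more memory for larger batches.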


## Pretrain
Please refer to [PRETRAIN.md](PRETRAIN.md) for instructions on how to pretrain TinyLlama.

## Finetune
We include a simple full-parameter finetuning & inference script in [sft](sft). Our V0.1 chat model is finetuned using this script. The FT dataset we use is [openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco). 
For finetuning with less than 4GB RAM, we refer you to the [QLoRA](https://github.com/artidoro/qlora) and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) repos.
We did not perform extensive hyperparameter tuning or look for better-performing FT datasets. We hope the community will explore finetuning TinyLlama and come up with better chat models. We will include community-finetuned models in this repo.

## TODO
This project is still under active development. We are a really small team. Community feedback and contributions are highly appreciated. Here are some things we plan to work on:
 - [ ] Add scripts for pretraining on other datasets.
 - [ ] Sequence length extrapolation.
 - [ ] Test out speculative decoding for Llama-2-7B.
 - [ ] Test the throughput on RTX 3090/4090. 
 - [ ] Add fine-tuning scripts.
 - [ ] Properly evaluate the model on downstream tasks.
 - [ ] A demo running on mobile phones. 
 - [ ] Explore retrieval-augmentation.



## Acknowledgements
This repository is built upon [lit-gpt](https://github.com/Lightning-AI/lit-gpt) and [flash-attention](https://github.com/Dao-AILab/flash-attention). Be sure to explore these fantastic open-source projects if they are new to you!
```
@online{lit-gpt,
  author    = {Lightning AI},
  title     = {Lit-GPT},
  url       = {https://github.com/Lightning-AI/lit-gpt},
  year      = {2023},
}
@article{dao2023flashattention2,
  title     ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author    ={Dao, Tri},
  year      ={2023}
}
```

## Citation
This project is currently contributed by [Peiyuan Zhang](https://veiled-texture-20c.notion.site/Peiyuan-Zhang-ab24b48621c9491db767a76df860873a?pvs=4) *, [Guangtao Zeng](https://github.com/ChaosCodes) *, [Tianduo Wang](https://github.com/TianduoWang) and [Wei Lu](https://istd.sutd.edu.sg/people/faculty/lu-wei/) from the StatNLP Research Group of Singapore University of Technology and Design. 

If you find our work valuable, please cite:

```
@misc{zhang2024tinyllama,
      title={TinyLlama: An Open-Source Small Language Model}, 
      author={Peiyuan Zhang and Guangtao Zeng and Tianduo Wang and Wei Lu},
      year={2024},
      eprint={2401.02385},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

## Frequently Asked Questions

#### 1. Why would pretraining a 1.1B model for so long make sense? Doesn't it contradict the Chinchilla Scaling Law?

<img src=".github/llama2-training.png" alt="The training loss curve of Llama 2" width="500"/>

Above is the training loss curve taken from the Llama 2 paper. Here I quote from that paper: "We observe that after pretraining on 2T Tokens, the models still did not show any sign of saturation". That is why we believe pretraining a 1.1B model for 3T tokens is a reasonable thing to do. Even if the loss curve does not go down eventually, we can still study the phenomenon of saturation and learn something from it.

#### 2. What does "saturation" mean?
<img src=".github/Pythia_saturation.png" alt="Figure 10 of the Pythia paper" width="500"/>

The figure from the Pythia paper displays the LAMBADA accuracy plotted against the total training tokens (300B). The term "saturation" pertains specifically to the 70M and 160M models. Notably, even the 410M model does not saturate with 300B tokens, as it continues to show an increasing trend, similar to the trend of larger models. 


## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=jzhang38/TinyLlama&type=Date)](https://star-history.com/#jzhang38/TinyLlama&Date)



================================================
FILE: README_zh-CN.md
================================================
<div align="center">

# TinyLlama-1.1B
[English](README.md) | 中文

[Chat Demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat)
</div>

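The TinyLlama project aims to pretrain a 1.1B-parameter Llama model on 3 trillion tokens. With some careful optimization, we can achieve this within "just" 90 days using 16 A100-40G GPUs 🚀🚀. Training started on 2023-09-01.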


<div align="center">
  <img src=".github/TinyLlama_logo.png" width="300"/>
</div>
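We adopted exactly the same architecture and tokenizer as Llama 2. This means TinyLlama can be plugged into many open-source projects built upon Llama. In addition, TinyLlama is compact, with only 1.1B parameters, which makes it suitable for a wide range of applications that demand a restricted computation and memory footprint.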

#### News

* 2023-12-18:
  * Added two documents, [1](https://whimsical-aphid-86d.notion.site/Release-of-TinyLlama-1-5T-Checkpoints-Postponed-01b266998c1c47f78f5ae1520196d194?pvs=4) and [2](https://whimsical-aphid-86d.notion.site/Latest-Updates-from-TinyLlama-Team-7d30c01fff794da28ccc952f327c8d4f?pvs=4), that explain the changes to the training curves, the project schedule, and the bug fixes.
* 2023-10-03:
  * Added llama.cpp code examples for speculative decoding. See [speculative_decoding/README.md](speculative_decoding/README.md) for details.
* 2023-10-02:
  * The 1T-token checkpoint has just been released.
  * We document **all** intermediate checkpoints on [huggingface](https://huggingface.co/TinyLlama/tinyLlama-intermediate-checkpoints/tree/step-480k-token-1007B).
* 2023-09-28:
  * Launched a [Discord](https://discord.gg/74Wcx4j5Nb) server.
* 2023-09-18:
  * Released a [chat demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat). Follow the link to try our model.
* 2023-09-16:
  * Released the intermediate [checkpoint](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b) trained on 503B tokens.
  * Finetuned the 503B-token [checkpoint](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b) on the OpenAssistant dataset and open-sourced the resulting chat model, [TinyLlama-Chat-V0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1), together with our [finetuning script](sft).
  * Added more evaluation benchmarks. See [EVAL.md](EVAL.md) for the results of each model.




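#### Release Schedule

We will release intermediate checkpoints according to the following schedule. We also list some baseline models for comparison.

Base models: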

| Date       | Model Weights                                                | Tokens | Step | Commonsense Avg |
| ---------- | ------------------------------------------------------------ | ------ | ---- | --------------- |
| 2023-09-01 | Pythia-1.0B                                                  | 300B   | 143k | 48.30           |
| 2023-09-04 | [TinyLlama-1.1B-intermediate-step-50k-105b](https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b) ([ModelScope](https://www.modelscope.cn/models/chaoscodes/TinyLlama-1.1B-step-50K-105b/files)) | 105B   | 50k  | 46.11           |
| 2023-09-16 | [TinyLlama-1.1B-intermediate-step-240k-503b](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-240k-503b) ([ModelScope](https://www.modelscope.cn/models/chaoscodes/TinyLlama-1.1B-intermediate-step-240k-503b/files)) | 503B   | 240K | 48.28           |
| 2023-10-01 | [TinyLlama-1.1B-intermediate-step-480k-1T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-480k-1T) | 1T     | 480k | 50.22 |
| 2023-11-04 | [TinyLlama-1.1B-intermediate-step-715k-1.5T](https://huggingface.co/PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T)                                            | 1.5T     |715k    |51.28 |
| 2023-11-20 | [TinyLlama-1.1B-intermediate-step-955k-2T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T)                                            | 2T     |955k    |51.64 |
| 2023-12-11 | [TinyLlama-1.1B-intermediate-step-1195k-2.5T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T)              | 2.5T     | 1195k    |53.86 |
| 2023-12-28 | [TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T)              | 3T   | 1431k  | 52.99 |

Chat models:

| Date       | Model Weights                                   | Tokens | Step | Commonsense Avg |
|------------|-------------------------------------------------|--------|------| --------------- |
| 2023-09-16 | [TinyLlama-1.1B-Chat-V0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) ([ModelScope](https://www.modelscope.cn/models/chaoscodes/TinyLlama-1.1B-Chat-v0.1/files))                                         | 503B   | 240K    |  49.57 |
| 2023-10-01 | [TinyLlama-1.1B-Chat-V0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3)                                            | 1T   | 480K    |  51.36 |
| 2023-11-04 | [TinyLlama-1.1B-Chat-V0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4)                                            | 1.5T   | 715K    |  52.30 |

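Note that because our model is still at an early stage of training and the learning rate has not yet fully stabilized, you will get a better experience by downloading our [chat model](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) or by trying the model through the [chat demo](https://huggingface.co/spaces/TinyLlama/tinyllama-chat).


You can also track TinyLlama's training loss in real time [here](https://api.wandb.ai/links/lance777/pgvhrsny).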

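## Potential Use Cases
Small but powerful language models are useful for many applications. Here are some potential use cases:
- Assisting speculative decoding of larger models.
- Deployment on edge devices, for example real-time machine translation without an internet connection (the 4-bit quantized TinyLlama weights take up only 550MB of memory).
- Enabling real-time dialogue generation in video games (the model has to be small because the game itself also needs GPU memory).

Moreover, our code can serve as a **concise reference for beginners getting started with pretraining**. If you want to train a language model with fewer than 5 billion parameters, you do not really need Megatron-LM.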
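
The 550MB figure for 4-bit weights follows from the parameter count. The sketch below is a rough estimate that assumes exactly 1.1 billion parameters. It ignores quantization metadata such as scales and zero-points, as well as activations and the KV cache. The function name `weight_megabytes` is hypothetical.

```python
N_PARAMS = 1.1e9  # nominal parameter count of TinyLlama


def weight_megabytes(n_params: float, bits_per_weight: int) -> float:
    """Approximate weight storage in MB (1 MB = 1e6 bytes) at a given precision."""
    return n_params * bits_per_weight / 8 / 1e6


for bits in (16, 8, 4):
    print(f"{bits:>2}-bit weights: about {weight_megabytes(N_PARAMS, bits):.0f} MB")
```

Real quantized files are usually somewhat larger than this estimate because they also store per-block scaling data.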

## Training Details
Below are some details of our training setup:

| Setting                         | Description                                                    |
|---------------------------------|----------------------------------------------------------------|
| Parameters                      | 1.1B                                                           |
| Attention Variant               | Grouped Query Attention                                        |
| Model Size                      | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size (Swiglu): 5632|
| Sequence Length                 | 2048                                                           |
| Batch Size                      | 2 million tokens (2048 * 1024)                                             |
| Learning Rate                   | 4e-4                                                           |
| Learning Rate Schedule          | Cosine with 2000 warmup steps                                  |
| Training Data                   | [Slimpajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) & [Starcoderdata](https://huggingface.co/datasets/bigcode/starcoderdata) |
| Data Preprocessing              | Excluded GitHub subset of Slimpajama; Sampled all code from Starcoderdata |
| Combined Dataset Size           | Around 950B tokens                                              |
| Total Tokens During Training    | 3 trillion (slightly more than 3 epochs/1430k steps)                                         |
| Natural Language to Code Ratio  | 7:3                                                            |
| Hardware                        | 16 A100-40G GPUs                                               |






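## Blazingly Fast
Our codebase supports the following features:
- multi-GPU and multi-node distributed training with FSDP
- flash attention 2
- fused layernorm
- fused swiglu
- fused cross entropy loss
- fused rotary positional embedding

Credit: flash attention 2, fused layernorm, fused cross entropy loss, and fused rotary positional embedding come from the [FlashAttention](https://github.com/Dao-AILab/flash-attention/) repo. Fused swiglu comes from [xformers](https://github.com/facebookresearch/xformers).

With these optimizations, we achieve a training throughput of **24k tokens per second per A100**, which corresponds to 56% model FLOPs utilization (MFU); the MFU would be even higher on an A100-80G. At this speed you can train a **Chinchilla-optimal model (1.1B parameters, 22B tokens) in 32 hours on 8 A100s**. The optimizations also greatly reduce the memory footprint: the 1.1B model fits into 40GB of GPU memory while sustaining a per-GPU batch size of 16k tokens. With a smaller batch size, you can also train TinyLlama on an **RTX 3090/4090**.
Below is a comparison of the training speed of our codebase with that of Pythia and MPT.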


| Model                             | A100 GPU hours taken on 300B tokens| 
|-----------------------------------|------------------------------------|
|TinyLlama-1.1B                     | 3456                               |    
|[Pythia-1.0B](https://huggingface.co/EleutherAI/pythia-1b)                        | 4830                               |
|[MPT-1.3B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b)                           | 7920                               |  

<small> The Pythia number comes from their paper. The MPT number comes from [here](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b), in which the authors say MPT-1.3B "was trained on 440 A100-40GBs for about half a day" on 200B tokens. </small>
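
The training-time claims in this section can be cross-checked against the quoted throughput of 24k tokens per second per A100. The sketch below assumes that throughput scales perfectly linearly with the number of GPUs and ignores restarts and evaluation time. The function name `training_hours` is hypothetical.

```python
TOKENS_PER_SEC_PER_GPU = 24_000  # throughput quoted in this section


def training_hours(total_tokens: float, num_gpus: int) -> float:
    """Wall-clock hours to process total_tokens, assuming linear GPU scaling."""
    tokens_per_sec = TOKENS_PER_SEC_PER_GPU * num_gpus
    return total_tokens / tokens_per_sec / 3600


# Chinchilla-optimal run for a 1.1B model: 22B tokens on 8 A100s.
chinchilla_hours = training_hours(22e9, num_gpus=8)

# Full TinyLlama run: 3T tokens on 16 A100s.
full_run_days = training_hours(3e12, num_gpus=16) / 24

print(f"22B tokens on 8 A100s: {chinchilla_hours:.1f} hours")
print(f"3T tokens on 16 A100s: {full_run_days:.1f} days")
```

Both results are in line with the 32-hour and 90-day figures quoted in this README.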

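TinyLlama is a relatively small model that also uses grouped query attention (GQA), which means it is fast during inference as well. Below are some inference speeds that we measured: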

| Framework | Device | Settings | Throughput (tokens/sec) |
|-----------|--------------|-----|-----------|
|[Llama.cpp](https://github.com/ggerganov/llama.cpp) | Mac M2 16GB RAM         |  batch_size=1; 4-bit inference|    71.8     | 
|[vLLM](https://github.com/vllm-project/vllm)       | A40 GPU  | batch_size=100, n=10 |   7094.5         |


## Pretrain
Please refer to [PRETRAIN.md](PRETRAIN.md).



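## Finetune

* We include our finetuning and inference code in [sft](sft). Using this code, we finetuned on the [openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset to obtain the first version of our [chat model](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1).
* To finetune our model on a GPU with less than 4GB of memory, refer to the [Qlora](https://github.com/artidoro/qlora) and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) projects.
* So far we have neither searched hyperparameters extensively nor selected potentially better instruction datasets for finetuning. We hope to encourage open research on TinyLlama in the NLP community and the release of better finetuned chat models. We will also include those models in this project.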



## TODO
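This project is still under active development. We are a small team, and community feedback and contributions are very welcome. Here are some things we plan to work on: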
 - [ ] Add scripts for pretraining on other datasets.
 - [ ] Sequence length extrapolation.
 - [ ] Test out speculative decoding for Llama-2-7B.
 - [ ] Test the throughput on RTX 3090/4090. 
 - [ ] Add fine-tuning scripts.
 - [ ] Properly evaluate the model on downstream tasks.
 - [ ] A demo running on mobile phones. 
 - [ ] Explore retrieval-augmentation.

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=jzhang38/TinyLlama&type=Date)](https://star-history.com/#jzhang38/TinyLlama&Date)


## Acknowledgements
This repository is built upon the excellent open-source projects [lit-gpt](https://github.com/Lightning-AI/lit-gpt) and [flash-attention](https://github.com/Dao-AILab/flash-attention).
```
@online{lit-gpt,
  author    = {Lightning AI},
  title     = {Lit-GPT},
  url       = {https://github.com/Lightning-AI/lit-gpt},
  year      = {2023},
}
@article{dao2023flashattention2,
  title     ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author    ={Dao, Tri},
  year      ={2023}
}
```

## Citation
This project is currently contributed by [Peiyuan Zhang](https://github.com/jzhang38), [Guangtao Zeng](https://github.com/ChaosCodes), [Tianduo Wang](https://github.com/TianduoWang) and [Wei Lu](https://istd.sutd.edu.sg/people/faculty/lu-wei/).

If you find our work valuable, please cite:

```
@misc{zhang2024tinyllama,
      title={TinyLlama: An Open-Source Small Language Model}, 
      author={Peiyuan Zhang and Guangtao Zeng and Tianduo Wang and Wei Lu},
      year={2024},
      eprint={2401.02385},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```



================================================
FILE: chat_gradio/README.md
================================================
## TinyLlama Chatbot Implementation with Gradio

We offer an easy way to interact with TinyLlama. This guide explains how to set up a local Gradio demo for a chatbot using TinyLlama.
(A demo is also available on the Hugging Face Space [TinyLlama/tinyllama_chatbot](https://huggingface.co/spaces/TinyLlama/tinyllama-chat) and on [Colab](https://colab.research.google.com/drive/1qAuL5wTIa-USaNBu8DH35KQtICTnuLsy?usp=sharing).)

### Requirements
* Python>=3.8
* PyTorch>=2.0
* Transformers>=4.35.0
* Gradio>=4.13.0

### Installation
`pip install -r requirements.txt`

### Usage

`python TinyLlama/chat_gradio/app.py`

* After running it, open the local URL displayed in your terminal in your web browser. (For server setup, use SSH local port forwarding with the command: `ssh -L [local port]:localhost:[remote port] [username]@[server address]`.)
* Interact with the chatbot by typing questions or commands.


**Note:** The chatbot's performance may vary based on your system's hardware. Ensure your system meets the above requirements for optimal experience.


================================================
FILE: chat_gradio/app.py
================================================
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# using CUDA for an optimal experience
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [2]  # Token IDs that end generation; 2 is the EOS token (`</s>`) of the Llama tokenizer.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False


# Function to generate model predictions.
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Formatting the input for the model.
    messages = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                        for item in history_transformer_format])
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message


# Setting up the Gradio chat interface.
gr.ChatInterface(predict,
                 title="Tinyllama_chatBot",
                 description="Ask TinyLlama any question",
                 examples=['How to cook a fish?', 'Who is the president of the US now?']
                 ).launch()  # Launching the web interface.


================================================
FILE: chat_gradio/requirements.txt
================================================
torch>=2.0
transformers>=4.35.0
gradio>=4.13.0


================================================
FILE: lit_gpt/__init__.py
================================================
from lit_gpt.model import GPT
from lit_gpt.config import Config
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.fused_cross_entropy import FusedCrossEntropyLoss
from lightning_utilities.core.imports import RequirementCache

if not bool(RequirementCache("torch>=2.1.0dev")):
    raise ImportError(
        "Lit-GPT requires torch nightly (future torch 2.1). Please follow the installation instructions in the"
        " repository README.md"
    )
_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
if not bool(_LIGHTNING_AVAILABLE):
    raise ImportError(
        "Lit-GPT requires Lightning nightly (future lightning 2.1). Please run:\n"
        f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
    )


__all__ = ["GPT", "Config", "Tokenizer"]


================================================
FILE: lit_gpt/adapter.py
================================================
"""Implementation of the paper:

LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199

Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from typing_extensions import Self

from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import KVCache, RoPECache, apply_rope


@dataclass
class Config(BaseConfig):
    adapter_prompt_length: int = 10
    adapter_start_layer: int = 2


class GPT(BaseModel):
    """The implementation is identical to `lit_gpt.model.GPT` with the exception that
    the `Block` saves the layer index and passes it down to the attention layer."""

    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[torch.Tensor] = None
        self.kv_caches: List[KVCache] = []
        self.adapter_kv_caches: List[KVCache] = []

    def reset_cache(self) -> None:
        super().reset_cache()
        self.adapter_kv_caches.clear()

    def forward(
        self,
        idx: torch.Tensor,
        max_seq_length: Optional[int] = None,
        input_pos: Optional[torch.Tensor] = None,
        lm_head_chunk_size: int = 0,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        B, T = idx.size()
        use_kv_cache = input_pos is not None

        block_size = self.config.block_size
        if max_seq_length is None:
            max_seq_length = block_size
        if use_kv_cache:  # not relevant otherwise
            assert (
                max_seq_length >= T
            ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
        assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
        assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"

        if self.rope_cache is None:
            self.rope_cache = self.build_rope_cache(idx)
        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
        # for the kv-cache support (only during inference), we only create it in that situation
        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
        if use_kv_cache and self.mask_cache is None:
            self.mask_cache = self.build_mask_cache(idx)

        cos, sin = self.rope_cache
        if use_kv_cache:
            cos = cos.index_select(0, input_pos)
            sin = sin.index_select(0, input_pos)
            mask = self.mask_cache.index_select(2, input_pos)
            mask = mask[:, :, :, :max_seq_length]
        else:
            cos = cos[:T]
            sin = sin[:T]
            mask = None

        # forward the model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)

        if not use_kv_cache:
            for block in self.transformer.h:
                x, *_ = block(x, (cos, sin), max_seq_length)
        else:
            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1))
            self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)]
            for i, block in enumerate(self.transformer.h):
                x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
                    x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
                )

        x = self.transformer.ln_f(x)

        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
        return self.lm_head(x)  # (b, t, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, CausalSelfAttention):
            module.reset_parameters()


class Block(nn.Module):
    """The implementation is identical to `lit_gpt.model.Block` with the exception that
    we replace the attention layer where adaption is implemented."""

    def __init__(self, config: Config, block_idx: int) -> None:
        super().__init__()
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        max_seq_length: int,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        adapter_kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
        n_1 = self.norm_1(x)
        h, new_kv_cache, new_adapter_kv_cache = self.attn(
            n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache
        )
        if self.config.parallel_residual:
            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
            x = x + h + self.mlp(n_2)
        else:
            if self.config.shared_attention_norm:
                raise NotImplementedError(
                    "No checkpoint amongst the ones we support uses this configuration"
                    " (non-parallel residual and shared attention norm)."
                )
            x = x + h
            x = x + self.mlp(self.norm_2(x))
        return x, new_kv_cache, new_adapter_kv_cache


class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
    over the adaption prompt."""

    def __init__(self, config: Config, block_idx: int) -> None:
        super().__init__(config)
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            self.reset_parameters()
        self.block_idx = block_idx

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        max_seq_length: int,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        adapter_kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # repeat k and v if necessary
        if self.config.n_query_groups != 1:  # doing this would require a full kv cache with MQA (inefficient!)
            # for MHA this is a no-op
            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        n_elem = int(self.config.rotary_percentage * self.config.head_size)

        cos, sin = rope
        q_roped = apply_rope(q[..., :n_elem], cos, sin)
        k_roped = apply_rope(k[..., :n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
        k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
            # check if reached token limit
            if input_pos[-1] >= max_seq_length:
                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                # shift 1 position to the left
                cache_k = torch.roll(cache_k, -1, dims=2)
                cache_v = torch.roll(cache_v, -1, dims=2)
            k = cache_k.index_copy_(2, input_pos, k)
            v = cache_v.index_copy_(2, input_pos, v)
            kv_cache = k, v

        y = self.scaled_dot_product_attention(q, k, v, mask=mask)

        if self.block_idx >= self.config.adapter_start_layer:
            aT = self.config.adapter_prompt_length
            if adapter_kv_cache is not None:
                ak, av = adapter_kv_cache
            else:
                prefix = self.adapter_wte.weight.reshape(1, aT, C)
                aqkv = self.attn(prefix)
                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
                aqkv = aqkv.permute(0, 2, 3, 1, 4)
                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
                if self.config.n_query_groups != 1:
                    # for MHA this is a no-op
                    ak = ak.repeat_interleave(q_per_kv, dim=2)
                    av = av.repeat_interleave(q_per_kv, dim=2)
                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
                adapter_kv_cache = (ak, av)

            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
            ay = self.scaled_dot_product_attention(q, ak, av, amask)
            y = y + self.gating_factor * ay

        y = y.reshape(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.proj(y)

        return y, kv_cache, adapter_kv_cache

    def reset_parameters(self) -> None:
        torch.nn.init.zeros_(self.gating_factor)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with older checkpoints."""
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def mark_only_adapter_as_trainable(model: GPT) -> None:
    """Sets `requires_grad=False` for all non-adapter weights."""
    for name, param in model.named_parameters():
        param.requires_grad = adapter_filter(name, param)


def adapter_filter(key: str, value: Any) -> bool:
    return "adapter_wte" in key or "gating_factor" in key


================================================
FILE: lit_gpt/adapter_v2.py
================================================
"""Implementation of the paper:

LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
https://arxiv.org/abs/2304.15010

Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from typing_extensions import Self

import lit_gpt
from lit_gpt.adapter import GPT as BaseModel
from lit_gpt.adapter import Block as BaseBlock
from lit_gpt.adapter import Config as BaseConfig
from lit_gpt.adapter import KVCache, RoPECache
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import apply_rope
from lit_gpt.utils import map_old_state_dict_weights


@dataclass
class Config(BaseConfig):
    @property
    def mlp_class(self) -> Type:
        return getattr(lit_gpt.adapter_v2, self._mlp_class)


def adapter_filter(key: str, value: Any) -> bool:
    adapter_substrings = (
        # regular adapter v1 parameters
        "adapter_wte",
        "gating_factor",
        # adapter v2: new bias and scale used in Linear
        "adapter_scale",
        "adapter_bias",
        # adapter v2: Norm parameters are now trainable
        "norm_1",
        "norm_2",
        "ln_f",
    )
    return any(s in key for s in adapter_substrings)


class AdapterV2Linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.adapter_scale * (self.linear(x) + self.adapter_bias)

    def reset_parameters(self) -> None:
        nn.init.zeros_(self.adapter_bias)
        nn.init.ones_(self.adapter_scale)


class GPT(BaseModel):
    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[torch.Tensor] = None
        self.kv_caches: List[KVCache] = []
        self.adapter_kv_caches: List[KVCache] = []

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, CausalSelfAttention):
            module.reset_parameters()
        if isinstance(module, AdapterV2Linear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {"lm_head.weight": "lm_head.linear.weight"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class Block(BaseBlock):
    """The implementation is identical to `lit_gpt.model.Block` with the exception that
    we replace the attention layer where adaption is implemented."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    def __init__(self, config: Config, block_idx: int) -> None:
        """Causal self-attention that computes the q, k and v projections with a single matrix* and adds the
        LLaMA-Adapter prefix (a learned adaption prompt with a zero-initialized gate) for parameter-efficient fine-tuning.

        *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
        query, key and value for each head) we can do this in a single pass with a single weight matrix.
        """
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
        # output projection
        self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            self.reset_parameters()
        self.block_idx = block_idx

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        max_seq_length: int,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        adapter_kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # repeat k and v if necessary
        if self.config.n_query_groups != 1:  # doing this would require a full kv cache with MQA (inefficient!)
            # for MHA this is a no-op
            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        n_elem = int(self.config.rotary_percentage * self.config.head_size)

        cos, sin = rope
        q_roped = apply_rope(q[..., :n_elem], cos, sin)
        k_roped = apply_rope(k[..., :n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
        k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
            # check if reached token limit
            if input_pos[-1] >= max_seq_length:
                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                # shift 1 position to the left
                cache_k = torch.roll(cache_k, -1, dims=2)
                cache_v = torch.roll(cache_v, -1, dims=2)
            k = cache_k.index_copy_(2, input_pos, k)
            v = cache_v.index_copy_(2, input_pos, v)
            kv_cache = k, v

        y = self.scaled_dot_product_attention(q, k, v, mask=mask)

        if self.block_idx >= self.config.adapter_start_layer:
            aT = self.config.adapter_prompt_length
            if adapter_kv_cache is not None:
                ak, av = adapter_kv_cache
            else:
                prefix = self.adapter_wte.weight.reshape(1, aT, C)
                aqkv = self.attn(prefix)
                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
                aqkv = aqkv.permute(0, 2, 3, 1, 4)
                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
                if self.config.n_query_groups != 1:
                    # for MHA this is a no-op
                    ak = ak.repeat_interleave(q_per_kv, dim=2)
                    av = av.repeat_interleave(q_per_kv, dim=2)
                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
                adapter_kv_cache = (ak, av)

            amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
            ay = self.scaled_dot_product_attention(q, ak, av, amask)
            y = y + self.gating_factor * ay

        y = y.reshape(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.proj(y)

        return y, kv_cache, adapter_kv_cache

    def reset_parameters(self) -> None:
        torch.nn.init.zeros_(self.gating_factor)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        # For compatibility with older checkpoints
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
    """Sets `requires_grad=False` for all non-adapter weights."""
    for name, param in model.named_parameters():
        param.requires_grad = adapter_filter(name, param)


================================================
FILE: lit_gpt/config.py
================================================
from dataclasses import dataclass
from typing import Any, Literal, Optional, Type

import torch
from typing_extensions import Self

import lit_gpt.model
from lit_gpt.utils import find_multiple


@dataclass
class Config:
    org: str = "Lightning-AI"
    name: str = "lit-GPT"
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │         │        │                 │
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
    #         MHA                    GQA                   MQA
    #   n_query_groups=4       n_query_groups=2      n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    _norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
    intermediate_size: Optional[int] = None
    condense_ratio: int = 1

    def __post_init__(self):
        # error checking
        assert self.n_embd % self.n_head == 0
        # vocab size should be a multiple of `padding_multiple` to be efficient on hardware. compute the closest value
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head
        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self._mlp_class == "LLaMAMLP":
                raise ValueError("The config needs to set the `intermediate_size`")
            self.intermediate_size = 4 * self.n_embd

    @property
    def head_size(self) -> int:
        return self.n_embd // self.n_head

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        conf_dict = name_to_config[name].copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @property
    def mlp_class(self) -> Type:
        # `self._mlp_class` cannot be the type to keep the config json serializable
        return getattr(lit_gpt.model, self._mlp_class)

    @property
    def norm_class(self) -> Type:
        # `self._norm_class` cannot be the type to keep the config json serializable
        if self._norm_class == "RMSNorm":
            from lit_gpt.rmsnorm import RMSNorm

            return RMSNorm
        elif self._norm_class == "FusedRMSNorm":
            from lit_gpt.rmsnorm import FusedRMSNorm

            return FusedRMSNorm
        return getattr(torch.nn, self._norm_class)
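

# Illustrative sketch (not part of the original file): the vocabulary padding
# step in `Config.__post_init__` can be pictured as follows. This assumes that
# `find_multiple(n, k)` rounds `n` up to the nearest multiple of `k`;
# `_sketch_round_up` is a hypothetical stand-in, not the library helper.
def _sketch_round_up(n: int, k: int) -> int:
    """Return the smallest multiple of `k` that is >= `n` (for positive `k`)."""
    return n if n % k == 0 else n + k - (n % k)


# Under that assumption, the defaults above (vocab_size=50254,
# padding_multiple=512) would give a padded vocabulary of 50688, while a
# vocabulary of 32000 with padding_multiple=64 is already aligned.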


########################
# Stability AI StableLM
########################
configs = [
    # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
    dict(org="stabilityai", name="stablelm-base-alpha-3b", padding_multiple=512),
    # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
    dict(org="stabilityai", name="stablelm-base-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
    dict(org="stabilityai", name="stablelm-tuned-alpha-3b", n_head=32, padding_multiple=512),
    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
    dict(org="stabilityai", name="stablelm-tuned-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
]

####################
# EleutherAI Pythia
####################
pythia = [
    # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
    dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128),
    # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
    dict(
        org="EleutherAI", name="pythia-160m", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128
    ),
    # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
    dict(
        org="EleutherAI", name="pythia-410m", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=128
    ),
    # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
    dict(org="EleutherAI", name="pythia-1b", block_size=2048, n_layer=16, n_embd=2048, n_head=8, padding_multiple=128),
    # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
    dict(
        org="EleutherAI", name="pythia-1.4b", block_size=2048, n_layer=24, n_embd=2048, n_head=16, padding_multiple=128
    ),
    # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
    dict(
        org="EleutherAI", name="pythia-2.8b", block_size=2048, n_layer=32, n_embd=2560, n_head=32, padding_multiple=128
    ),
    # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
    dict(
        org="EleutherAI", name="pythia-6.9b", block_size=2048, n_layer=32, n_embd=4096, n_head=32, padding_multiple=256
    ),
    # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
    dict(
        org="EleutherAI", name="pythia-12b", block_size=2048, n_layer=36, n_embd=5120, n_head=40, padding_multiple=512
    ),
]
configs.extend(pythia)
for c in pythia:
    copy = c.copy()
    copy["name"] = f"{c['name']}-deduped"
    configs.append(copy)


####################################
# togethercomputer RedPajama INCITE
####################################
redpajama_incite = [
    # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
    dict(
        org="togethercomputer",
        name="RedPajama-INCITE-{}-3B-v1",
        block_size=2048,
        n_layer=32,
        n_embd=2560,
        n_head=32,
        padding_multiple=256,
        rotary_percentage=1.0,
        parallel_residual=False,
    ),
    # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
    dict(
        org="togethercomputer",
        name="RedPajama-INCITE-7B-{}",
        block_size=2048,
        n_layer=32,
        n_embd=4096,
        n_head=32,
        padding_multiple=256,
        rotary_percentage=1.0,
        parallel_residual=False,
    ),
    # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
    dict(
        org="togethercomputer",
        name="RedPajama-INCITE-{}-7B-v0.1",
        block_size=2048,
        n_layer=32,
        n_embd=4096,
        n_head=32,
        padding_multiple=256,
        rotary_percentage=1.0,
        parallel_residual=False,
    ),
]
for c in redpajama_incite:
    for kind in ("Base", "Chat", "Instruct"):
        copy = c.copy()
        copy["name"] = c["name"].format(kind)
        configs.append(copy)


#################
# TII UAE Falcon
#################
falcon = [
    # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
    dict(
        org="tiiuae",
        name="falcon-7b{}",
        block_size=2048,
        padded_vocab_size=65024,
        n_layer=32,
        n_head=71,
        n_embd=4544,
        rotary_percentage=1.0,
        parallel_residual=True,
        n_query_groups=1,
        bias=False,
        # this is not in the config, but in the original model implementation, only for this config
        shared_attention_norm=True,
    ),
    # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
    dict(
        org="tiiuae",
        name="falcon-40b{}",
        block_size=2048,
        padded_vocab_size=65024,
        n_layer=60,
        n_head=128,
        n_embd=8192,
        rotary_percentage=1.0,
        parallel_residual=True,
        n_query_groups=8,
        bias=False,
    ),
]
for c in falcon:
    for kind in ("", "-instruct"):
        copy = c.copy()
        copy["name"] = c["name"].format(kind)
        configs.append(copy)


#############################
# StatNLP Research
#############################
tiny_LLaMA = [
     
    # https://twitter.com/cwolferesearch/status/1691929174175264858
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_1b",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
    ),
    dict(
        org="StatNLP-research",
        name="tiny_LLaMA_120M",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=12,
        n_head=12,
        n_embd=768,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=2048,
        n_query_groups=1,
    ),
    dict(
        org="StatNLP-research",
        name="code_tiny_LLaMA_1b",
        block_size=8192,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=22,
        n_head=32,
        n_embd=2048,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="FusedRMSNorm",
        norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
        _mlp_class="LLaMAMLP",
        intermediate_size=5632,
        n_query_groups=4,
        condense_ratio=4,
    ),
]
configs.extend(tiny_LLaMA)
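

# Illustrative sketch (not part of the original file): the attention shapes
# implied by a grouped-query config such as `tiny_LLaMA_1b` above. It reuses
# the formulas from `Config.head_size` and from the fused qkv projection in the
# attention modules. `_sketch_attention_shapes` is a hypothetical helper added
# only for illustration.
def _sketch_attention_shapes(n_embd: int, n_head: int, n_query_groups: int) -> dict:
    head_size = n_embd // n_head  # width of a single attention head
    q_per_kv = n_head // n_query_groups  # query heads sharing one key/value pair
    # fused projection width: all query heads plus one key and one value per group
    qkv_features = (n_head + 2 * n_query_groups) * head_size
    return {"head_size": head_size, "q_per_kv": q_per_kv, "qkv_features": qkv_features}


# For n_embd=2048, n_head=32, n_query_groups=4 this gives head_size=64,
# q_per_kv=8 and qkv_features=2560, versus 6144 for plain multi-head attention
# (n_query_groups=32).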


#############################
# OpenLM Research Open LLaMA
#############################
open_LLaMA = [
    # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
    dict(
        org="openlm-research",
        name="open_llama_3b",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=26,
        n_head=32,
        n_embd=3200,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=8640,
    ),
    # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
    dict(
        org="openlm-research",
        name="open_llama_7b",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
    dict(
        org="openlm-research",
        name="open_llama_13b",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
]
configs.extend(open_LLaMA)


###############
# LMSYS Vicuna
###############
vicuna = [
    # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
    dict(
        org="lmsys",
        name="vicuna-7b-v1.3",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
    dict(
        org="lmsys",
        name="vicuna-13b-v1.3",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
    dict(
        org="lmsys",
        name="vicuna-33b-v1.3",
        block_size=2048,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=60,
        n_head=52,
        n_embd=6656,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=17920,
    ),
    dict(
        org="lmsys",
        name="vicuna-7b-v1.5",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    dict(
        org="lmsys",
        name="vicuna-7b-v1.5-16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        condense_ratio=4,
    ),
    dict(
        org="lmsys",
        name="vicuna-13b-v1.5",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    dict(
        org="lmsys",
        name="vicuna-13b-v1.5-16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        condense_ratio=4,
    ),
]
configs.extend(vicuna)


#################
# LMSYS LongChat
#################
long_chat = [
    # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
    dict(
        org="lmsys",
        name="longchat-7b-16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
        condense_ratio=8,
    ),
    # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
    dict(
        org="lmsys",
        name="longchat-13b-16k",
        block_size=16384,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
        condense_ratio=8,
    ),
]
configs.extend(long_chat)


######################
# NousResearch Hermes
######################
nous_research = [
    # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
    dict(
        org="NousResearch",
        name="Nous-Hermes-13b",
        block_size=2048,
        padded_vocab_size=32001,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-6,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    )
]
configs.extend(nous_research)


###############
# Meta LLaMA 2
###############
llama_2 = [
    # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
    dict(
        org="meta-llama",
        name="Llama-2-7b{}-hf",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    dict(
        org="meta-llama",
        name="CodeLlama-2-7b-hf",
        block_size=4096,
        vocab_size=32016,
        padded_vocab_size=32016,
        padding_multiple=64,
        n_layer=32,
        n_head=32,
        n_embd=4096,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=11008,
    ),
    # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
    dict(
        org="meta-llama",
        name="Llama-2-13b{}-hf",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=40,
        n_head=40,
        n_embd=5120,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=13824,
    ),
    # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
    dict(
        org="meta-llama",
        name="Llama-2-70b{}-hf",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    ),
]
for c in llama_2:
    for kind in ("", "-chat"):
        copy = c.copy()
        copy["name"] = c["name"].format(kind)
        configs.append(copy)


##########################
# Stability AI FreeWilly2
##########################
freewilly_2 = [
    # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
    dict(
        org="stabilityai",
        name="FreeWilly2",
        block_size=4096,
        vocab_size=32000,
        padding_multiple=64,
        n_layer=80,
        n_head=64,
        n_embd=8192,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        _norm_class="RMSNorm",
        norm_eps=1e-5,
        _mlp_class="LLaMAMLP",
        intermediate_size=28672,
    )
]
configs.extend(freewilly_2)


name_to_config = {config["name"]: config for config in configs}


================================================
FILE: lit_gpt/fused_cross_entropy.py
================================================
# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).

            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        return grad_logits, None, None, None, None, None, None


class FusedCrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        inplace_backward=True,
        process_group=None,
    ):
        super().__init__()
        if reduction not in ["mean", "none"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward
        self.process_group = process_group

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        if len(input.shape) == 3:
            input = input.view(-1, input.size(-1))
            target = target.view(-1)
        loss = SoftmaxCrossEntropyLossFn.apply(
            input,
            target,
            self.label_smoothing,
            self.ignore_index,
            self.inplace_backward,
            self.process_group,
        )
        if self.reduction == "mean":
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
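

# Illustrative sketch only (not part of the original module). The tensor-parallel branch of
# `SoftmaxCrossEntropyLossFn.forward` relies on the fact that the log-sum-exp over the full
# vocabulary equals the log-sum-exp of the per-partition log-sum-exps. The hypothetical helper
# below checks that identity in plain Python for the no-smoothing case; its name, arguments and
# the even split of the vocabulary across ranks are assumptions made for illustration.
import math  # only needed by the sketch below


def _lse_partition_sketch(logits, label, world_size):
    def logsumexp(values):
        m = max(values)
        return m + math.log(sum(math.exp(v - m) for v in values))

    vocab_size = len(logits) // world_size  # per-partition vocab size, assumed to divide evenly
    shards = [logits[r * vocab_size : (r + 1) * vocab_size] for r in range(world_size)]
    lse_local = [logsumexp(shard) for shard in shards]  # what each rank computes on its own shard
    lse = logsumexp(lse_local)  # what every rank can compute after the all-gather
    owner = label // vocab_size  # rank whose partition contains the label
    # The owning rank contributes lse_local - logit and the other ranks contribute 0, so the
    # all-reduced loss only needs the correction (lse - lse_local[owner]).
    loss_local = lse_local[owner] - logits[label]
    loss = loss_local + (lse - lse_local[owner])
    return loss, lse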

================================================
FILE: lit_gpt/fused_rotary_embedding.py
================================================
# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple

import rotary_emb
import torch
from einops import rearrange, repeat

class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
        )
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (
            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
        )
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (
                dx_ro.chunk(2, dim=-1)
                if not ctx.interleaved
                else (dx_ro[..., ::2], dx_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            do1,
            do2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dx1,
            dx2,
            True,
        )
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply



================================================
FILE: lit_gpt/lora.py
================================================
# Derived from https://github.com/microsoft/LoRA
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

r"""
    Low-Rank Adaptation (LoRA) for LLMs scheme.

             ┌───────────────────┐
             ┆         h         ┆
             └───────────────────┘
                       ▲
                       |
                       +
                    /     \
    ┌─────────────────┐    ╭───────────────╮     Matrix initialization:
    ┆                 ┆     \      B      /      B = 0
    ┆   pretrained    ┆      \    r*d    /       A = N(0, sigma^2)
    ┆    weights      ┆       ╰─────────╯
    ┆                 ┆       |    r    |        r - rank
    ┆   W e R^(d*d)   ┆       | ◀─────▶ |
    ┆                 ┆       ╭─────────╮
    └─────────────────┘      /     A     \
              ▲             /     d*r     \
               \           ╰───────────────╯
                \                ▲
                 \              /
                  \            /
             ┌───────────────────┐
             ┆         x         ┆
             └───────────────────┘

With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning weights of size d*d,
we freeze the pretrained weights and learn two matrices of size d*r and r*d that store the weight updates
for the pretrained weights. The number of trainable parameters is reduced drastically (depending on the rank, of
course), yet multiplying the d*r and r*d matrices still gives a d*d matrix that we can add to the frozen
pretrained weights and thus fine-tune the model.

The goal of this approach is to move weight updates into a separate matrix which is decomposed into
two matrices of lower rank.
"""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self

import lit_gpt
from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import Block as BaseBlock
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import KVCache, RoPECache
from lit_gpt.utils import map_old_state_dict_weights
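

# Illustrative sketch only (not part of the original module). The module docstring says that
# learning d*r and r*d matrices instead of a d*d matrix drastically reduces the number of
# trainable parameters. The hypothetical helper below just does that arithmetic; the square
# d*d weight is the simplifying assumption used in the docstring, not a property of every layer.
def _lora_param_count_sketch(d: int, r: int) -> tuple:
    full = d * d  # parameters in the frozen pretrained weight W
    lora = r * d + d * r  # parameters in the trainable matrices A (r x d) and B (d x r)
    return full, lora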


class LoRALayer(nn.Module):
    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
        """Store LoRA specific attributes in a class.

        Args:
            r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        """
        super().__init__()
        assert r >= 0
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False


class LoRALinear(LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        """LoRA wrapper around linear class.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        """
        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            self.reset_parameters()

    def reset_parameters(self):
        """Reset the LoRA weights (A and B); the pretrained weights are left untouched."""
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def merge(self):
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
        if self.r > 0 and not self.merged:
            # Merge the weights and mark it
            self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        # if weights are merged or rank is less than or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass;
        # otherwise in addition do the forward pass with LoRA weights and add its output to the output from pretrained weights
        pretrained = self.linear(x)
        if self.r == 0 or self.merged:
            return pretrained
        lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return pretrained + lora


class LoRAQKVLinear(LoRALinear):
    # LoRA implemented in a dense layer that computes q, k and v with a single weight matrix
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        n_head: int,
        n_query_groups: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
        **kwargs,
    ):
        """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            n_head: number of attention heads
            n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
            r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
            enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
                don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query`
                and `value` but keep `key` without weight updates we should pass `[True, False, True]`
        """
        super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        if isinstance(enable_lora, bool):
            enable_lora = [enable_lora] * 3
        assert len(enable_lora) == 3
        self.enable_lora = enable_lora

        # Actual trainable parameters
        # To better understand initialization let's imagine that we have such parameters:
        # ⚬ in_features: 128 (embeddings_size)
        # ⚬ out_features: 384 (3 * embedding_size)
        # ⚬ r: 2
        # ⚬ enable_lora: [True, False, True]
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r * sum(enable_lora), in_features)))  # (4, 128)
            enable_q, enable_k, enable_v = enable_lora
            self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
            # qkv_shapes will be used to split a tensor with weights correctly
            qkv_shapes = (
                self.linear.in_features * enable_q,
                self.kv_embd_size * enable_k,
                self.kv_embd_size * enable_v,
            )
            self.qkv_shapes = [s for s in qkv_shapes if s]
            self.lora_B = nn.Parameter(self.linear.weight.new_zeros(sum(self.qkv_shapes), r))  # (256, 2)
            # Notes about shapes above
            # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
            # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
            # F.linear function weights are automatically transposed. In addition conv1d requires channels to
            # be before seq length
            # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
            # 128*2; 2 tells to have two channels per group for group convolution

            # Scaling:
            # This balances the pretrained model's knowledge and the new task-specific adaptation
            # https://lightning.ai/pages/community/tutorial/lora-llm/
            # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
            # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
            # tune these values to your needs. This value can be even slightly greater than 1.0!
            # https://github.com/cloneofsimo/lora
            self.scaling = self.lora_alpha / self.r

            # Compute the indices
            # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
            # but not keys, then the weights update should be:
            #
            # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
            #  [....................................],
            #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            #      ↑              ↑            ↑
            # ________________________________________
            # | query         | key       | value    |
            # ----------------------------------------
            self.lora_ind = []
            if enable_q:
                self.lora_ind.extend(range(0, self.linear.in_features))
            if enable_k:
                self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
            if enable_v:
                self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
            self.reset_parameters()

    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
        """Properly pad weight updates with zeros.

        If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
        then the weights update should be:

        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
         [....................................],
         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            ↑              ↑            ↑
        ________________________________________
        | query         | key       | value    |
        ----------------------------------------

        Args:
            x: tensor with weights update that will be padded with zeros if necessary

        Returns:
            A tensor with weight updates and zeros for deselected q, k or v
        """
        # we need to do zero padding only if LoRA is disabled for one of QKV matrices
        if all(self.enable_lora):
            return x

        # Let's imagine that:
        # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
        # ⚬ embeddings_size: 128
        # ⚬ self.linear.out_features: 384 (3 * embeddings_size)
        # ⚬ enable_lora: [True, False, True]
        # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
        # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
        # only for key updates (this is where self.lora_ind comes in handy)
        # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
        # for example when we want to merge/unmerge LoRA weights and pretrained weights
        x = x.transpose(0, 1)
        result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
        result = result.view(-1, self.linear.out_features)  # (4096, 384)
        result = result.index_copy(
            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
        )  # (4096, 384)
        return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)

    def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.

        If the number of heads is equal to the number of query groups - grouped queries are disabled
        (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
        query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the
        input and weight matrices will be split into equally sized parts and applied separately (like having multiple
        conv layers side by side).

        Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,
        apply each part of the weight matrix to the corresponding input's part and concatenate the result.

        Args:
            input: input matrix of shape (B, C, T)
            weight: weight matrix of shape (C_output, rank, 1).
                "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).

        Returns:
            A tensor with a shape (B, C_output, T)

        """
        if self.n_head == self.n_query_groups:
            return F.conv1d(input, weight, groups=sum(self.enable_lora))  # (B, C_output, T)

        # Notation:
        # ⚬ N: number of enabled LoRA layers (self.enable_lora)
        # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
        # ⚬ r: rank of all LoRA layers (equal in size)

        input_splitted = input.chunk(sum(self.enable_lora), dim=1)  # N * (B, C // N, T)
        weight_splitted = weight.split(self.qkv_shapes)  # N * (C_output', r, 1)
        return torch.cat(
            [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1  # (B, C_output', T)
        )  # (B, C_output, T)

    def merge(self):
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""

        # Let's assume that:
        # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
        # ⚬ self.lora_A.data: (4, 128)
        # ⚬ self.lora_B.data: (256, 2)
        if self.r > 0 and any(self.enable_lora) and not self.merged:
            delta_w = self.conv1d(
                self.lora_A.data.unsqueeze(0),  # (4, 128) -> (1, 4, 128)
                self.lora_B.data.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
            ).squeeze(
                0
            )  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
            # W = W + delta_W (merge)
            self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass.

        If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
        If not, then multiply pretrained weights with input, apply LoRA on input and do summation.

        Args:
            x: input tensor of shape (batch_size, context_length, embedding_size)

        Returns:
            Output tensor of shape (batch_size, context_length, 3 * embedding_size)
        """

        # Let's assume that:
        # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
        # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
        # ⚬ self.lora_A.data: (4, 128)
        # ⚬ self.lora_B.data: (256, 2)

        # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass;
        # otherwise in addition do the forward pass with LoRA weights and add its output to the output from pretrained weights
        pretrained = self.linear(x)
        if self.r == 0 or not any(self.enable_lora) or self.merged:
            return pretrained
        after_A = F.linear(self.lora_dropout(x), self.lora_A)  # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
        # For F.conv1d:
        # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
        # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
        after_B = self.conv1d(
            after_A.transpose(-2, -1),  # (64, 64, 4) -> (64, 4, 64)
            self.lora_B.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
        ).transpose(
            -2, -1
        )  # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
        lora = self.zero_pad(after_B) * self.scaling  # (64, 64, 256) after zero_pad (64, 64, 384)
        return pretrained + lora
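

# Illustrative sketch only (not part of the original module). It shows how `LoRAQKVLinear` decides
# which output columns receive a LoRA update: the hypothetical helper mirrors the `lora_ind`
# construction in `__init__` and the zero-filling done by `zero_pad`, using plain Python lists for
# a single output row. The helper name and the layout out_features = in_features + 2 * kv_embd_size
# are assumptions made for illustration.
def _zero_pad_sketch(update, in_features, kv_embd_size, enable_lora):
    enable_q, enable_k, enable_v = enable_lora
    out_features = in_features + 2 * kv_embd_size
    ind = []
    if enable_q:
        ind.extend(range(0, in_features))
    if enable_k:
        ind.extend(range(in_features, in_features + kv_embd_size))
    if enable_v:
        ind.extend(range(in_features + kv_embd_size, out_features))
    assert len(update) == len(ind), "update must have one value per enabled output column"
    padded = [0.0] * out_features  # columns of disabled projections stay zero
    for i, value in zip(ind, update):
        padded[i] = value
    return padded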


def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze all parameters except LoRA's and, depending on the 'bias' value, unfreeze bias weights.

    Args:
        model: model with LoRA layers
        bias:
            ``"none"``: all bias weights will be frozen,
            ``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
            ``"all"``: all bias weights will be unfrozen.

    Raises:
        NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
    """
    # freeze all layers except LoRA's
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False

    # depending on the `bias` value unfreeze bias weights
    if bias == "none":
        return
    if bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError


def lora_filter(key: str, value: Any) -> bool:
    return "lora_" in key


@dataclass
class Config(BaseConfig):
    """
    Args:
        r: rank of the weight update matrices. For LoRA to be worthwhile, the rank should be smaller than the rank of
            the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
        alpha: alpha is needed for scaling updates as alpha/r
            "This scaling helps to reduce the need to retune hyperparameters when we vary r"
            https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
        dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        to_*: either apply LoRA to the specified weights or not
    """

    r: int = 0
    alpha: int = 1
    dropout: float = 0.0
    to_query: bool = False
    to_key: bool = False
    to_value: bool = False
    to_projection: bool = False
    to_mlp: bool = False
    to_head: bool = False

    @property
    def mlp_class(self) -> Type:
        return getattr(lit_gpt.lora, self._mlp_class)


class GPT(BaseModel):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = LoRALinear(
            config.n_embd,
            config.padded_vocab_size,
            bias=False,
            r=(config.r if config.to_head else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[torch.Tensor] = None
        self.kv_caches: List[KVCache] = []

    def forward(
        self,
        idx: torch.Tensor,
        max_seq_length: Optional[int] = None,
        input_pos: Optional[torch.Tensor] = None,
        lm_head_chunk_size: int = 0,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        B, T = idx.size()
        use_kv_cache = input_pos is not None

        block_size = self.config.block_size
        if max_seq_length is None:
            max_seq_length = block_size
        if use_kv_cache:  # not relevant otherwise
            assert (
                max_seq_length >= T
            ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
        assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
        assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"

        if self.rope_cache is None:
            self.rope_cache = self.build_rope_cache(idx)  # 2 * (block_size, head_size * rotary_percentage / 2)
        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
        # for the kv-cache support (only during inference), we only create it in that situation
        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
        if use_kv_cache and self.mask_cache is None:
            self.mask_cache = self.build_mask_cache(idx)  # (1, 1, block_size, block_size)

        cos, sin = self.rope_cache
        if use_kv_cache:
            cos = cos.index_select(0, input_pos)
            sin = sin.index_select(0, input_pos)
            mask = self.mask_cache.index_select(2, input_pos)
            mask = mask[:, :, :, :max_seq_length]
        else:
            cos = cos[:T]
            sin = sin[:T]
            mask = None

        # forward the model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)

        if not use_kv_cache:
            for block in self.transformer.h:
                x, *_ = block(x, (cos, sin), max_seq_length)
        else:
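            # the rope cache stores n_elem / 2 angles per position, so double it to match `lit_gpt.model.GPT.forward`
            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)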
            for i, block in enumerate(self.transformer.h):
                x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])

        x = self.transformer.ln_f(x)

        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
        return self.lm_head(x)  # (B, T, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module, n_layer: int) -> None:
        """Meant to be used with `gpt.apply(functools.partial(gpt._init_weights, n_layer=n_layer))`.

        Unused method left for completeness.
        """
        super()._init_weights(module, n_layer)
        if isinstance(module, LoRALinear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {"lm_head.weight": "lm_head.linear.weight"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class Block(BaseBlock):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    def __init__(self, config: Config) -> None:
        """Causal self-attention that computes the q, k, v projections with a single matrix* and uses Low-Rank
        Adaptation (LoRA) for parameter-efficient fine-tuning.

        *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
        query, key and value for each head) we can do this in a single pass with a single weight matrix.
        """
        # Skip the parent class __init__ altogether and replace it to avoid
        # useless allocations
        nn.Module.__init__(self)
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = LoRAQKVLinear(
            in_features=config.n_embd,
            out_features=shape,
            r=config.r,
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
            enable_lora=(config.to_query, config.to_key, config.to_value),
            bias=config.bias,
            # for MQA/GQA support
            n_head=config.n_head,
            n_query_groups=config.n_query_groups,
        )
        # output projection
        self.proj = LoRALinear(
            config.n_embd,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_projection else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.proj = LoRALinear(
            config.intermediate_size,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc_1 = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.fc_2 = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.proj = LoRALinear(
            config.intermediate_size,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

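    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The base `LLaMAMLP.forward` calls `self.swiglu`, which this class never creates,
        # so the SwiGLU computation is spelled out with the LoRA layers instead.
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        return self.proj(x)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: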
        """For compatibility with base checkpoints."""
        mapping = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def merge_lora_weights(model: GPT) -> None:
    """Merge LoRA weights into the full-rank weights to speed up inference."""
    for module in model.modules():
        if isinstance(module, LoRALinear):
            module.merge()
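

# --- Illustrative sketch, not part of the upstream lit_gpt/lora.py ---
# `merge_lora_weights` above delegates to `LoRALinear.merge()`. Assuming that method follows the
# standard LoRA formulation, merging folds the low-rank product into the frozen weight:
#     W' = W + (alpha / r) * (B @ A)
# The helper names below (`_sketch_matmul`, `sketch_merge_lora`) are hypothetical and work on
# plain nested lists rather than tensors, so the arithmetic can be checked by hand.
def _sketch_matmul(a, b):
    """Multiply two matrices stored as nested lists: (n, k) @ (k, m) -> (n, m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def sketch_merge_lora(weight, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * (B @ A), with W of shape (out, in), A (r, in) and B (out, r)."""
    scaling = alpha / r
    delta = _sketch_matmul(lora_b, lora_a)  # (out, r) @ (r, in) -> (out, in)
    return [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]


# Example: sketch_merge_lora([[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0]], [[1.0], [0.0]], alpha=2, r=1)
# After merging, one matmul with W' replaces the separate base and LoRA branches at inference time.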


================================================
FILE: lit_gpt/model.py
================================================
"""Full definition of a GPT NeoX Language Model, all of it in this single file.

Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""
import math
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import Self
from flash_attn import flash_attn_func
from lit_gpt.config import Config
from xformers.ops import SwiGLU
from .fused_rotary_embedding import apply_rotary_emb_func
RoPECache = Tuple[torch.Tensor, torch.Tensor]
KVCache = Tuple[torch.Tensor, torch.Tensor]
FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")


class GPT(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[torch.Tensor] = None
        self.kv_caches: List[KVCache] = []

    def _init_weights(self, module: nn.Module, n_layer: int) -> None:
        """Meant to be used with `gpt.apply(functools.partial(gpt._init_weights, n_layer=n_layer))`."""
        # GPT-NeoX  https://arxiv.org/pdf/2204.06745.pdf
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
            # RWKV: set it to 1e-4
            # torch.nn.init.uniform_(module.weight,  -1e-4, 1e-4)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        # GPT-NeoX       
        for name, p in module.named_parameters():
            # if the xformers SwiGLU is used, the fc2 layer is renamed to w3
            if (
                (name == "proj.weight" and isinstance(module, LLaMAMLP))
                or (name == "w3.weight" and isinstance(module, SwiGLU))
                or (name == "proj.weight" and isinstance(module, CausalSelfAttention))
            ):
                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
        

    def reset_cache(self) -> None:
        self.kv_caches.clear()
        if self.mask_cache is not None and self.mask_cache.device.type == "xla":
            # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179
            self.rope_cache = None
            self.mask_cache = None

    def forward(
        self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, T = idx.size()
        use_kv_cache = input_pos is not None

        block_size = self.config.block_size
        if max_seq_length is None:
            max_seq_length = block_size
        if use_kv_cache:  # not relevant otherwise
            assert (
                max_seq_length >= T
            ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
        assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
        assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"

        if self.rope_cache is None:
            self.rope_cache = self.build_rope_cache(idx)
        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
        # for the kv-cache support (only during inference), we only create it in that situation
        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
        if use_kv_cache and self.mask_cache is None:
            self.mask_cache = self.build_mask_cache(idx)

        cos, sin = self.rope_cache
        if use_kv_cache:

            cos = cos.index_select(0, input_pos)
            sin = sin.index_select(0, input_pos)
            mask = self.mask_cache.index_select(2, input_pos)
            mask = mask[:, :, :, :max_seq_length]
        else:
            cos = cos[:T]
            sin = sin[:T]
            mask = None

        # forward the model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
            
        if not use_kv_cache:
            for block in self.transformer.h:
                x, *_ = block(x, (cos, sin), max_seq_length)
        else:
            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
            for i, block in enumerate(self.transformer.h):
                x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])

        x = self.transformer.ln_f(x)

        return self.lm_head(x)  # (b, t, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
        return build_rope_cache(
            seq_len=self.config.block_size,
            n_elem=int(self.config.rotary_percentage * self.config.head_size),
            dtype=torch.bfloat16,
            device=idx.device,
            condense_ratio=self.config.condense_ratio,
        )

    def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
        ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:
        B = idx.size(0)
        heads = self.config.n_query_groups

        k_cache_shape = (
            B,
            max_seq_length,
            heads,
            rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),
        )
        v_cache_shape = (B, max_seq_length, heads, self.config.head_size)
        device = idx.device
        return [
            (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
            for _ in range(self.config.n_layer)
        ]


class Block(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        max_seq_length: int,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:

        n_1 = self.norm_1(x)
        h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)
        if self.config.parallel_residual:
            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
            x = x + h + self.mlp(n_2)
        else:
            if self.config.shared_attention_norm:
                raise NotImplementedError(
                    "No checkpoint amongst the ones we support uses this configuration"
                    " (non-parallel residual and shared attention norm)."
                )
            
            x = x + h
            x = x + self.mlp(self.norm_2(x))
        return x, new_kv_cache


class CausalSelfAttention(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        max_seq_length: int,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)
        # qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)

        # repeat k and v if necessary
        # Peiyuan: we do not need to do this as FlashAttention-2 already supports GQA
        # if self.config.n_query_groups != 1:  # doing this would require a full kv cache with MQA (inefficient!)
        #     # for MHA this is a no-op
        #     k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
        #     v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

        q = q.reshape(B,  T, -1, self.config.head_size)  # (B, T, nh_q, hs)
        k = k.reshape(B,  T, -1, self.config.head_size)  
        v = v.reshape(B,  T, -1, self.config.head_size)  

        cos, sin = rope

        # applying RoPE in fp32 significantly stabilizes training
        # the fused RoPE kernel expects (batch_size, seqlen, nheads, headdim)
        q = apply_rotary_emb_func(q, cos, sin, False, True)
        k = apply_rotary_emb_func(k, cos, sin, False, True)
        
        # n_elem = int(self.config.rotary_percentage * self.config.head_size)
    
        # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
        # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
        # print( (q_roped - q).sum())
        # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
        # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
            # check if reached token limit
            if input_pos[-1] >= max_seq_length:
                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                # shift 1 position to the left
                cache_k = torch.roll(cache_k, -1, dims=1)
                cache_v = torch.roll(cache_v, -1, dims=1)

            k = cache_k.index_copy_(1, input_pos, k)
            v = cache_v.index_copy_(1, input_pos, v)
            kv_cache = k, v

        y = self.scaled_dot_product_attention(q, k, v, mask=mask)

        y = y.reshape(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.proj(y)

        return y, kv_cache

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ):
        scale = 1.0 / math.sqrt(self.config.head_size)
        
        if (
            FlashAttention2Available
            and mask is None
            and q.device.type == "cuda"
            and q.dtype in (torch.float16, torch.bfloat16)
        ):
            from flash_attn import flash_attn_func

            return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if q.size() != k.size():
            # expand the shared k/v heads so every query head has a matching key/value head (GQA/MQA)
            k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
            v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)


class GptNeoxMLP(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        x = torch.nn.functional.gelu(x)
        return self.proj(x)


class LLaMAMLP(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        # self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.swiglu = SwiGLU(config.n_embd, config.intermediate_size, bias=False, _pack_weights=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x_fc_1 = self.fc_1(x)
        # x_fc_2 = self.fc_2(x)
        # x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        # return self.proj(x)
        return self.swiglu(x)


def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
) -> RoPECache:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta)

    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)

    # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
    if dtype == torch.bfloat16:
        return cos.bfloat16(), sin.bfloat16()
    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        return cos.half(), sin.half()
    return cos, sin


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.type_as(x)
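

# --- Illustrative sketch, not part of the upstream lit_gpt/model.py ---
# A pure-Python walk-through of what `build_rope_cache` and `apply_rope` compute for a single
# head vector, using the same half-split layout (first half paired with second half). It also
# shows the property RoPE is used for: the dot product of a rotated query and key depends only
# on the distance between their positions. The `sketch_*` names are hypothetical helpers.
import math  # already imported at the top of this module; repeated so the sketch runs standalone


def sketch_rope_angles(pos, n_elem, base=10000, condense_ratio=1):
    """Angles pos * theta_i with theta_i = base ** (-i / n_elem) for i = 0, 2, ..., n_elem - 2."""
    return [(pos / condense_ratio) / (base ** (i / n_elem)) for i in range(0, n_elem, 2)]


def sketch_apply_rope(x, pos, base=10000):
    """Rotate the even-length list `x` as if it sat at sequence position `pos`."""
    half = len(x) // 2
    angles = sketch_rope_angles(pos, len(x), base)
    cos = [math.cos(a) for a in angles] * 2  # repeated so it covers both halves of x
    sin = [math.sin(a) for a in angles] * 2
    rotated = [-v for v in x[half:]] + x[:half]  # same as torch.cat((-x2, x1), dim=-1)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotated, sin)]


# Example: the scores for positions (5, 3) and (7, 5) agree up to rounding, since both are 2 apart:
#   sum(a * b for a, b in zip(sketch_apply_rope(q, 5), sketch_apply_rope(k, 3)))
#   sum(a * b for a, b in zip(sketch_apply_rope(q, 7), sketch_apply_rope(k, 5)))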


================================================
FILE: lit_gpt/packed_dataset.py
================================================
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py


import os
import random
import struct

import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info

dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}


def code(dtype):
    for k in dtypes:
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24  # bytes


class PackedDataset(IterableDataset):
    def __init__(
        self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
    ):
        self._filenames = filenames
        self._n_chunks = n_chunks
        self._block_size = block_size
        self._seed = seed
        self._shuffle = shuffle
        self._wrap = wrap
        self._num_processes = num_processes
        self._process_rank = process_rank

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        max_num_files = len(self._filenames) // num_shards * num_shards
        filenames = self._filenames[shard_id:max_num_files:num_shards]

        return PackedDatasetIterator(
            filenames=filenames,
            n_chunks=self._n_chunks,
            block_size=self._block_size,
            seed=self._seed,
            shuffle=self._shuffle,
            wrap=self._wrap,
        )
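

# --- Illustrative sketch, not part of the upstream lit_gpt/packed_dataset.py ---
# `PackedDataset.__iter__` above gives every (process, dataloader worker) pair a disjoint,
# equally sized slice of the file list and drops the files that do not divide evenly.
# The hypothetical helper below repeats that arithmetic on its own so it can be tried
# without a DataLoader.
def sketch_shard_filenames(filenames, num_processes, process_rank, num_workers, worker_id):
    """Return the files one worker of one process would read."""
    num_shards = num_workers * num_processes
    shard_id = process_rank * num_workers + worker_id
    # round down to a multiple of num_shards so every shard gets the same number of files
    max_num_files = len(filenames) // num_shards * num_shards
    return filenames[shard_id:max_num_files:num_shards]


# Example: with 10 files, 2 processes and 2 workers each, every shard gets 2 files and the
# last 2 files are never read:
#   sketch_shard_filenames(list(range(10)), 2, 1, 2, 0) returns [2, 6]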


class PackedDatasetBuilder(object):
    def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
        if dtype == "auto":
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            if vocab_size < 65500:
                self._dtype = np.uint16
            else:
                self._dtype = np.int32
        else:
            self._dtype = dtype
        self._counter = 0
        self._chunk_size = chunk_size
        self._outdir = outdir
        self._prefix = prefix
        self._sep_token = sep_token
        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
        self._arr.fill(self._sep_token)
        self._idx = 0
        self._version = 1
        self._filenames = []

    def _write_chunk(self):
        filename = f"{self._prefix}_{self._counter:010d}.bin"
        filename = os.path.join(self._outdir, filename)

        with open(filename, "wb") as f:
            f.write(HDR_MAGIC)
            f.write(struct.pack("<Q", self._version))
            f.write(struct.pack("<B", code(self._dtype)))
            f.write(struct.pack("<Q", self._chunk_size))
            f.write(self._arr.tobytes(order="C"))

        self._filenames.append(filename)
        self._counter += 1
        self._arr.fill(self._sep_token)
        self._idx = 0

    @property
    def dtype(self):
        return self._dtype

    @property
    def filenames(self):
        return self._filenames.copy()

    def add_array(self, arr):
        while self._idx + arr.shape[0] > self._chunk_size:
            part_len = self._chunk_size - self._idx
            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
            self._write_chunk()
            arr = arr[part_len:]

        arr_len = arr.shape[0]
        self._arr[self._idx : self._idx + arr_len] = arr
        self._idx += arr_len

    def write_reminder(self):
        self._write_chunk()


class PackedDatasetIterator:
    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._block_idxs = None

        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        #       (or text file) with the sequence of all files to be
        #       fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        self._load_n_chunks()

    def _read_header(self, path):
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            version = struct.unpack("<Q", f.read(8))
            assert version == (1,)
            (dtype_code,) = struct.unpack("<B", f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtype, chunk_size

    def _close_mmaps(self):
        for mmap in self._mmaps:
            mmap._mmap.close()

    def _load_n_chunks(self):
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        if self._n_chunks > len(self._filenames[self._file_idx :]):
            # if not self._wrap:
            #     raise StopIteration
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)

        self._curr_idx = 0

    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self

    def __next__(self):
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        return torch.from_numpy(arr.astype(np.int64))
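

# Illustrative sketch, not part of the original module: the index arithmetic used by
# PackedDatasetIterator.__next__ above, pulled out into a standalone helper so it can
# be checked by hand. The name `_example_locate_block` and its parameters are
# hypothetical. It assumes every chunk holds exactly `n_blocks` blocks of
# `block_size` elements, each `itemsize` bytes wide.
def _example_locate_block(block_idx, n_blocks, block_size, itemsize):
    chunk_id = block_idx // n_blocks  # which memory-mapped chunk holds the block
    elem_id = (block_idx % n_blocks) * block_size  # first element of the block within that chunk
    byte_offset = itemsize * elem_id  # byte offset of that element in the chunk buffer
    return chunk_id, byte_offset


# For example, with 4 blocks per chunk, 8 elements per block and 2-byte tokens,
# block 5 lives in chunk 1 at byte offset 16: _example_locate_block(5, 4, 8, 2) == (1, 16).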


class CombinedDataset(IterableDataset):
    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            self._weights = [1 / n_datasets] * n_datasets

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)


class CombinedDatasetIterator:
    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)
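

# Illustrative sketch, not part of the original module: the sampling scheme of
# CombinedDatasetIterator above, reduced to plain iterables. The helper name
# `_example_weighted_interleave` and its parameters are hypothetical. Each draw picks
# one source with probability proportional to its weight and takes that source's next
# item; a seeded random.Random makes the order reproducible. If a chosen source is
# exhausted, next() raises StopIteration, as it would in the class above.
import random  # repeated here so the sketch runs standalone


def _example_weighted_interleave(sources, weights, seed, n_samples):
    rng = random.Random(seed)
    iterators = [iter(source) for source in sources]
    samples = []
    for _ in range(n_samples):
        # random.choices returns a list of length k, so unpack the single pick.
        (chosen,) = rng.choices(iterators, weights=weights, k=1)
        samples.append(next(chosen))
    return samples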


================================================
FILE: lit_gpt/rmsnorm.py
================================================
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16

import dropout_layer_norm
import torch
from torch.nn import init


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes"""
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        epsilon,
        1.0,
        0,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype == input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale,
        colscale,
        None,
        None,
        dropout_p,
        1.0,
        0,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype == input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes"""
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype == input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    (
        dx0mat,
        dx1mat,
        dresidualmat,
        dgamma0,
        dbeta0,
        dgamma1,
        dbeta1,
        *rest,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
        else:
            z = zmat.view(z_shape)
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)


def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,
        return_dropout_mask,
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
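        )


def rms_norm(x, weight, epsilon):
    # Arguments in order: x0=x, residual=None, gamma=weight, beta=None, rowscale=None,
    # colscale=None, dropout_p=0.0, epsilon, residual_in_fp32=False, prenorm=False,
    # is_rms_norm=True.
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )


class FusedRMSNorm(torch.nn.Module):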
    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.dim = dim
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)


class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the original RMSNorm paper implementation is not equivalent
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.weight * x_normed

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
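

# Illustrative sketch, not part of the original module: the arithmetic of
# RMSNorm.forward above, written for a single plain-Python vector so it can be checked
# without torch or the fused kernel. The helper name `_example_rms_norm` is
# hypothetical. It assumes normalization over the whole vector (the dim=-1 case) and
# one weight per element.
import math


def _example_rms_norm(values, weights, eps=1e-5):
    mean_square = sum(v * v for v in values) / len(values)  # torch.mean(x * x, dim=-1)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)  # torch.rsqrt(norm_x + eps)
    return [w * v * inv_rms for v, w in zip(values, weights)]


# With unit weights and eps=0.0 the output has a mean square of 1 (up to rounding);
# no mean is subtracted, which is what distinguishes RMSNorm from LayerNorm.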


================================================
FILE: lit_gpt/speed_monitor.py
================================================
import time
from collections import deque
from contextlib import nullcontext
from typing import Any, Callable, Deque, Dict, Optional

import torch
from lightning import Callback, Fabric, LightningModule, Trainer
from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
from torch.utils.flop_counter import FlopCounterMode
import math
from lit_gpt import GPT, Config
from lit_gpt.utils import num_parameters

GPU_AVAILABLE_FLOPS = {
    # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
    # nvidia publishes spec sheet with a 2x sparsity factor
    "h100-sxm": {
        "64-true": 67e12,
        "32-true": 67e12,
        "16-true": 1.979e15 / 2,
        "16-mixed": 1.979e15 / 2,
        "bf16-true": 1.979e15 / 2,
        "bf16-mixed": 1.979e15 / 2,
        "8-true": 3.958e15 / 2,
        "8-mixed": 3.958e15 / 2,
    },
    "h100-pcie": {
        "64-true": 51e12,
        "32-true": 51e12,
        "16-true": 1.513e15 / 2,
        "16-mixed": 1.513e15 / 2,
        "bf16-true": 1.513e15 / 2,
        "bf16-mixed": 1.513e15 / 2,
        "8-true": 3.026e15 / 2,
        "8-mixed": 3.026e15 / 2,
    },
    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
    # sxm and pcie have same flop counts
    "a100": {
        "64-true": 19.5e12,
        "32-true": 19.5e12,
        "16-true": 312e12,
        "16-mixed": 312e12,
        "bf16-true": 312e12,
        "bf16-mixed": 312e12,
    },
    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
    "a10g": {"32-true": 31.2e12, "16-true": 125e12, "16-mixed": 125e12, "bf16-true": 125e12, "bf16-mixed": 125e12},
    # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
    "v100-sxm": {"64-true": 7.8e12, "32-true": 15.7e12, "16-true": 125e12, "16-mixed": 125e12},
    "v100-pcie": {"64-true": 7e12, "32-true": 14e12, "16-true": 112e12, "16-mixed": 112e12},
    "v100s-pcie": {"64-true": 8.2e12, "32-true": 16.4e12, "16-true": 130e12, "16-mixed": 130e12},
    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
    # sxm and pcie have same flop counts
    "t4": {"32-true": 8.1e12, "16-true": 65e12, "16-mixed": 65e12, "8-true": 130e12, "int4": 260e12},
    # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
    "quadro rtx 5000": {"32-true": 11.2e12, "16-true": 89.2e12, "16-mixed": 89.2e12},
}

TPU_AVAILABLE_FLOPS = {
    # flop count for each TPU generation is the same for all precisions
    # since bfloat16 precision is always used for performing matrix operations
    # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
    # source: https://arxiv.org/pdf/1907.10701.pdf
    "v2": 45e12,
    # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
    "v3": 123e12,
    # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
    "v4": 275e12,
}


def get_flops_available(device: torch.device, precision: str) -> Optional[float]:
    if device.type == "cuda":
        device_name = torch.cuda.get_device_name(device).lower()
        if "h100" in device_name and "hbm3" in device_name:
            device_name = "h100-sxm"
        elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
            device_name = "h100-pcie"
        elif "a100" in device_name:
            device_name = "a100"
        elif "a10g" in device_name:
            device_name = "a10g"
        elif "v100-sxm" in device_name:
            device_name = "v100-sxm"
        elif "v100-pcie" in device_name:
            device_name = "v100-pcie"
        elif "v100s-pcie" in device_name:
            device_name = "v100s-pcie"
        elif "t4" in device_name:
            device_name = "t4"
        elif "quadro rtx 5000" in device_name:
            device_name = "quadro rtx 5000"
        else:
            device_name = None

        if device_name is not None:
            try:
                return int(GPU_AVAILABLE_FLOPS[device_name][precision])
            except KeyError:
                raise KeyError(
                    f"flop count not found for {device_name} with precision: {precision}; "
                    "MFU cannot be calculated and reported."
                )
    elif device.type == "xla":
        from torch_xla.experimental import tpu

        device_name = tpu.get_tpu_env()["TYPE"].lower()
        try:
            return int(TPU_AVAILABLE_FLOPS[device_name])
        except KeyError:
            raise KeyError(
                f"flop count not found for {device_name} with precision: {precision}; "
                "MFU cannot be calculated and reported."
            )

    return None
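

# Illustrative sketch, not part of the original module: how a peak-flops figure such as
# the one returned by get_flops_available is commonly turned into model flops
# utilization (MFU). The helper name `_example_mfu` and its parameters are
# hypothetical. It assumes `flops_per_batch` counts the flops of one global batch
# (summed over all devices) and `flops_available` is the peak flops of one device.
def _example_mfu(flops_per_batch, batches_per_sec, world_size, flops_available):
    flops_per_sec = flops_per_batch * batches_per_sec  # achieved flops across all devices
    device_flops_per_sec = flops_per_sec / world_size  # achieved flops per device
    return device_flops_per_sec / flops_available  # fraction of the device's peak


# For example, 8 devices with a 312e12 peak that together sustain 1.248e15 flops per
# second run at an MFU of about 0.5.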


# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py


class SpeedMonitorBase:
    """Logs the training throughput and utilization.

    +-------------------------------------+-----------------------------------------------------------+
    | Key                                 | Logged data                                               |
    +=====================================+===========================================================+
    |                                     | Rolling average (over `window_size` most recent           |
    | `throughput/batches_per_sec`        | batches) of the number of batches processed per second    |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | Rolling average (over `window_size` most recent           |
    | `throughput/samples_per_sec`        | batches) of the number of samples processed per second    |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | Rolling average (over `window_size` most recent           |
    | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |
    |                                     | This may include padding depending on dataset             |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |
    | `throughput/flops_per_sec`          |                                                           |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size        |
    +-------------------------------------+-----------------------------------------------------------+
    | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | `throughput/tokens_per_sec` divided by world size. This   |
    | `throughput/device/tokens_per_sec`  | may include pad tokens depending on dataset               |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | `throughput/flops_per_sec` divided by world size. Only    |
    | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |
    |                                     |                                                           |
    +-------------------------------------+-----------------------------------------------------------+
    |                                     | `throughput/device/flops_per_sec` divided by the          |
    | `throughput/device/mfu`             | `flops_available` passed to the constructor.              |
    |                                     | Logged only when `flops_available` is set.                |
    +-------------------------------------+-----------------------------------------------------------+
    | `time/train`                        | Total elapsed training time                               |
    +-------------------------------------+-----------------------------------------------------------+
    | `time/val`                          | Total elapsed validation time                             |
    +-------------------------------------+-----------------------------------------------------------+
    | `time/total`                        | Total elapsed time (time/train + time/val)                |
    +-------------------------------------+-----------------------------------------------------------+

    Notes:
        - The implementation assumes that devices are homogeneous as it normalizes by the world size.
        - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
          batches/sec to measure throughput under this circumstance.
        - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
          There is no widespread, realistic, and reliable implementation to compute them.
          We suggest using our ``measure_flops`` function, but many other works will use ``estimate_flops`` which
          will almost always be an overestimate when compared to the true value.

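    Args:
        flops_available (float): Peak flops per second available on a single device, used as the
            denominator of `throughput/device/mfu`.
        log_dict (Callable[[Dict, int], None]): Function called with the metrics dict and the step count.
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
        time_unit (str, optional): Time unit to use for `time` logging. Can be one of
            'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
        log_iter_interval (int, optional): Log metrics every `log_iter_interval` calls to
            `on_train_batch_end`. Defaults to 1.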
    """

    def __init__(
        self,
        flops_available: float,
        log_dict: Callable[[Dict, int], None],
        window_size: int = 100,
        time_unit: str = "hours",
        log_iter_interval: int = 1,
    ):
        self.flops_available = flops_available
        self.log_dict = log_dict
        self.log_iter_interval = log_iter_interval
        # Track the batch num samples and wct to compute throughput over a window of batches
        self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
        self.history_training_loss: Deque[Optional[float]] = deque(maxlen=log_iter_interval)
        self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
        self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
        self.history_flops: Deque[int] = deque(maxlen=window_size + 1)

        self.divider = 1
        if time_unit == "seconds":
            self.divider = 1
        elif time_unit == "minutes":
            self.divider = 60
        elif time_unit == "hours":
            self.divider = 60 * 60
        elif time_unit == "days":
            self.divider = 60 * 60 * 24
        else:
            raise ValueError(
                f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
            )

        # Keep track of time spent evaluating
        self.total_eval_wct = 0.0
        self.iter = -1

    def on_train_batch_end(
        self,
        samples: int,  # total samples seen (per device)
        train_elapsed: float,  # total training time (seconds)
        world_size: int,
        step_count: int,
        flops_per_batch: Optional[int] = None,  # (per device)
        lengths: Optional[int] = None,  # total length of the samples seen (per device)
        train_loss: Optional[float] = None,
    ):
        self.iter += 1
        metrics = {}

        self.history_samples.append(samples)
        self.history_training_loss.append(train_loss)
        if lengths is not None:
            self.history_lengths.append(lengths)
            # if lengths are passed, there should be as many values as samples
            assert len(self.history_samples) == len(self.history_lengths)
        self.history_wct.append(train_elapsed)
        if len(self.history_wct) == self.history_wct.maxlen:
            elapsed_batches = len(self.history_samples) - 1
            elapsed_samples = self.history_samples[-1] - self.history_samples[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            samples_per_sec = elapsed_samples * world_size / elapsed_wct
            dev_samples_per_sec = elapsed_samples / elapsed_wct
            metrics.update(
                {
                    "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
                    "throughput/samples_per_sec": samples_per_sec,
                    "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
                    "throughput/device/samples_per_sec": dev_samples_per_sec,
                }
            )
            if lengths is not None:
                elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
                avg_length = elapsed_lengths / elapsed_batches
                metrics.update(
                    {
                        "throughput/tokens_per_sec": samples_per_sec * avg_length,
                        "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
                        "total_tokens": avg_length * world_size * samples,
                    }
                )
                if train_loss is not None:
                    avg_loss = sum(self.history_training_loss) / len(self.history_training_loss)
                    metrics.update(
                        {
                            "metric/train_loss": avg_loss,
                            "metric/train_ppl": math.exp(avg_loss)
                        }
                    )

        if flops_per_batch is not None:
            # sum of flops per batch across ranks
            self.history_flops.append(flops_per_batch * world_size)
        if len(self.history_flops) == self.history_flops.maxlen:
            elapsed_flops = sum(self.history_flops) - self.history_flops[0]
            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
            flops_per_sec = elapsed_flops / elapsed_wct
            device_flops_per_sec = flops_per_sec / world_size
            metrics.update(
                {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
            )
            if self.flops_available:
                metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available

        metrics.update(
            {
                "time/train": train_elapsed / self.divider,
                "time/val": self.total_eval_wct / self.divider,
                "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
                "samples": samples,
            }
        )
        if self.iter % self.log_iter_interval == 0:
            self.log_dict(metrics, step_count)

    def eval_end(self, eval_elapsed: float):
        self.total_eval_wct += eval_elapsed  # seconds


class SpeedMonitorFabric(SpeedMonitorBase):
    def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
        # TODO: this will not work properly if a precision plugin is passed to Fabric
        flops_available = get_flops_available(fabric.device, fabric._connector.
SYMBOL INDEX (296 symbols across 21 files)

FILE: chat_gradio/app.py
  class StopOnTokens (line 17) | class StopOnTokens(StoppingCriteria):
    method __call__ (line 18) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  function predict (line 27) | def predict(message, history):

FILE: lit_gpt/adapter.py
  class Config (line 22) | class Config(BaseConfig):
  class GPT (line 27) | class GPT(BaseModel):
    method __init__ (line 31) | def __init__(self, config: Config) -> None:
    method reset_cache (line 50) | def reset_cache(self) -> None:
    method forward (line 54) | def forward(
    method from_name (line 115) | def from_name(cls, name: str, **kwargs: Any) -> Self:
    method _init_weights (line 118) | def _init_weights(self, module: nn.Module) -> None:
  class Block (line 125) | class Block(nn.Module):
    method __init__ (line 129) | def __init__(self, config: Config, block_idx: int) -> None:
    method forward (line 139) | def forward(
  class CausalSelfAttention (line 167) | class CausalSelfAttention(BaseCausalSelfAttention):
    method __init__ (line 171) | def __init__(self, config: Config, block_idx: int) -> None:
    method forward (line 181) | def forward(
    method reset_parameters (line 266) | def reset_parameters(self) -> None:
    method _load_from_state_dict (line 269) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  function mark_only_adapter_as_trainable (line 276) | def mark_only_adapter_as_trainable(model: GPT) -> None:
  function adapter_filter (line 282) | def adapter_filter(key: str, value: Any) -> bool:

FILE: lit_gpt/adapter_v2.py
  class Config (line 26) | class Config(BaseConfig):
    method mlp_class (line 28) | def mlp_class(self) -> Type:
  function adapter_filter (line 32) | def adapter_filter(key: str, value: Any) -> bool:
  class AdapterV2Linear (line 48) | class AdapterV2Linear(torch.nn.Module):
    method __init__ (line 49) | def __init__(self, in_features: int, out_features: int, **kwargs) -> N...
    method forward (line 56) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method reset_parameters (line 59) | def reset_parameters(self) -> None:
  class GPT (line 64) | class GPT(BaseModel):
    method __init__ (line 65) | def __init__(self, config: Config) -> None:
    method from_name (line 86) | def from_name(cls, name: str, **kwargs: Any) -> Self:
    method _init_weights (line 89) | def _init_weights(self, module: nn.Module) -> None:
    method _load_from_state_dict (line 97) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  class Block (line 104) | class Block(BaseBlock):
    method __init__ (line 108) | def __init__(self, config: Config, block_idx: int) -> None:
  class CausalSelfAttention (line 120) | class CausalSelfAttention(BaseCausalSelfAttention):
    method __init__ (line 121) | def __init__(self, config: Config, block_idx: int) -> None:
    method forward (line 145) | def forward(
    method reset_parameters (line 230) | def reset_parameters(self) -> None:
    method _load_from_state_dict (line 233) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  class GptNeoxMLP (line 248) | class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    method __init__ (line 249) | def __init__(self, config: Config) -> None:
    method _load_from_state_dict (line 254) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  class LLaMAMLP (line 266) | class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    method __init__ (line 267) | def __init__(self, config: Config) -> None:
    method _load_from_state_dict (line 273) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  function mark_only_adapter_v2_as_trainable (line 287) | def mark_only_adapter_v2_as_trainable(model: GPT) -> None:

FILE: lit_gpt/config.py
  class Config (line 12) | class Config:
    method __post_init__ (line 53) | def __post_init__(self):
    method head_size (line 71) | def head_size(self) -> int:
    method from_name (line 75) | def from_name(cls, name: str, **kwargs: Any) -> Self:
    method mlp_class (line 81) | def mlp_class(self) -> Type:
    method norm_class (line 86) | def norm_class(self) -> Type:

FILE: lit_gpt/fused_cross_entropy.py
  class SoftmaxCrossEntropyLossFn (line 15) | class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    method forward (line 17) | def forward(
    method backward (line 103) | def backward(ctx, grad_loss):
  class FusedCrossEntropyLoss (line 113) | class FusedCrossEntropyLoss(nn.Module):
    method __init__ (line 114) | def __init__(
    method forward (line 131) | def forward(self, input, target):

FILE: lit_gpt/fused_rotary_embedding.py
  class ApplyRotaryEmb (line 10) | class ApplyRotaryEmb(torch.autograd.Function):
    method forward (line 12) | def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
    method backward (line 56) | def backward(ctx, do):

FILE: lit_gpt/lora.py
  class LoRALayer (line 62) | class LoRALayer(nn.Module):
    method __init__ (line 63) | def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
  class LoRALinear (line 87) | class LoRALinear(LoRALayer):
    method __init__ (line 89) | def __init__(
    method reset_parameters (line 128) | def reset_parameters(self):
    method merge (line 136) | def merge(self):
    method forward (line 143) | def forward(self, x: torch.Tensor):
  class LoRAQKVLinear (line 153) | class LoRAQKVLinear(LoRALinear):
    method __init__ (line 155) | def __init__(
    method zero_pad (line 256) | def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
    method conv1d (line 298) | def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.T...
    method merge (line 333) | def merge(self):
    method forward (line 351) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function mark_only_lora_as_trainable (line 389) | def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") ->...
  function lora_filter (line 422) | def lora_filter(key: str, value: Any) -> bool:
  class Config (line 427) | class Config(BaseConfig):
    method mlp_class (line 450) | def mlp_class(self) -> Type:
  class GPT (line 454) | class GPT(BaseModel):
    method __init__ (line 455) | def __init__(self, config: Config) -> None:
    method forward (line 481) | def forward(
    method from_name (line 539) | def from_name(cls, name: str, **kwargs: Any) -> Self:
    method _init_weights (line 542) | def _init_weights(self, module: nn.Module) -> None:
    method _load_from_state_dict (line 548) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  class Block (line 555) | class Block(BaseBlock):
    method __init__ (line 556) | def __init__(self, config: Config) -> None:
  class CausalSelfAttention (line 567) | class CausalSelfAttention(BaseCausalSelfAttention):
    method __init__ (line 568) | def __init__(self, config: Config) -> None:
    method _load_from_state_dict (line 604) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  class GptNeoxMLP (line 616) | class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    method __init__ (line 617) | def __init__(self, config: Config) -> None:
    method _load_from_state_dict (line 636) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  class LLaMAMLP (line 648) | class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    method __init__ (line 649) | def __init__(self, config: Config) -> None:
    method _load_from_state_dict (line 676) | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: ...
  function merge_lora_weights (line 690) | def merge_lora_weights(model: GPT) -> None:

FILE: lit_gpt/model.py
  class GPT (line 22) | class GPT(nn.Module):
    method __init__ (line 23) | def __init__(self, config: Config) -> None:
    method _init_weights (line 40) | def _init_weights(self, module: nn.Module, n_layer) -> None:
    method reset_cache (line 57) | def reset_cache(self) -> None:
    method forward (line 64) | def forward(
    method from_name (line 116) | def from_name(cls, name: str, **kwargs: Any) -> Self:
    method build_rope_cache (line 119) | def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
    method build_mask_cache (line 128) | def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
    method build_kv_caches (line 132) | def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope...
  class Block (line 150) | class Block(nn.Module):
    method __init__ (line 151) | def __init__(self, config: Config) -> None:
    method forward (line 159) | def forward(
  class CausalSelfAttention (line 186) | class CausalSelfAttention(nn.Module):
    method __init__ (line 187) | def __init__(self, config: Config) -> None:
    method forward (line 197) | def forward(
    method scaled_dot_product_attention (line 268) | def scaled_dot_product_attention(
  class GptNeoxMLP (line 294) | class GptNeoxMLP(nn.Module):
    method __init__ (line 295) | def __init__(self, config: Config) -> None:
    method forward (line 300) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class LLaMAMLP (line 306) | class LLaMAMLP(nn.Module):
    method __init__ (line 307) | def __init__(self, config: Config) -> None:
    method forward (line 313) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function build_rope_cache (line 321) | def build_rope_cache(
  function apply_rope (line 350) | def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) ->...

FILE: lit_gpt/packed_dataset.py
  function code (line 16) | def code(dtype):
  class PackedDataset (line 27) | class PackedDataset(IterableDataset):
    method __init__ (line 28) | def __init__(
    method __iter__ (line 40) | def __iter__(self):
  class PackedDatasetBuilder (line 60) | class PackedDatasetBuilder(object):
    method __init__ (line 61) | def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto"...
    method _write_chunk (line 82) | def _write_chunk(self):
    method dtype (line 99) | def dtype(self):
    method filenames (line 103) | def filenames(self):
    method add_array (line 106) | def add_array(self, arr):
    method write_reminder (line 117) | def write_reminder(self):
  class PackedDatasetIterator (line 121) | class PackedDatasetIterator:
    method __init__ (line 122) | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
    method _read_header (line 150) | def _read_header(self, path):
    method _close_mmaps (line 161) | def _close_mmaps(self):
    method _load_n_chunks (line 165) | def _load_n_chunks(self):
    method __del__ (line 192) | def __del__(self):
    method __iter__ (line 197) | def __iter__(self):
    method __next__ (line 200) | def __next__(self):
  class CombinedDataset (line 214) | class CombinedDataset(IterableDataset):
    method __init__ (line 215) | def __init__(self, datasets, seed, weights=None):
    method __iter__ (line 223) | def __iter__(self):
  class CombinedDatasetIterator (line 227) | class CombinedDatasetIterator:
    method __init__ (line 228) | def __init__(self, datasets, seed, weights):
    method __next__ (line 233) | def __next__(self):

FILE: lit_gpt/rmsnorm.py
  function maybe_align (line 10) | def maybe_align(x, alignment_in_bytes=16):
  function _dropout_add_layer_norm_forward (line 17) | def _dropout_add_layer_norm_forward(
  function _dropout_add_layer_norm_backward (line 56) | def _dropout_add_layer_norm_backward(
  function _dropout_add_layer_norm_subset_forward (line 111) | def _dropout_add_layer_norm_subset_forward(
  function _dropout_add_layer_norm_subset_backward (line 154) | def _dropout_add_layer_norm_subset_backward(
  function _dropout_add_layer_norm_parallel_residual_forward (line 213) | def _dropout_add_layer_norm_parallel_residual_forward(
  function _dropout_add_layer_norm_parallel_residual_backward (line 258) | def _dropout_add_layer_norm_parallel_residual_backward(
  class DropoutAddLayerNormFn (line 312) | class DropoutAddLayerNormFn(torch.autograd.Function):
    method forward (line 314) | def forward(
    method backward (line 375) | def backward(ctx, dz, *args):
  class DropoutAddLayerNormSubsetFn (line 417) | class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    method forward (line 419) | def forward(
    method backward (line 484) | def backward(ctx, dz, *args):
  class DropoutAddLayerNormParallelResidualFn (line 532) | class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    method forward (line 534) | def forward(
    method backward (line 606) | def backward(ctx, dz0, dz1, *args):
  function layer_norm (line 658) | def layer_norm(x, weight, bias, epsilon):
  function dropout_add_layer_norm (line 662) | def dropout_add_layer_norm(
  function dropout_add_layer_norm_subset (line 694) | def dropout_add_layer_norm_subset(
  function dropout_add_layer_norm_parallel_residual (line 732) | def dropout_add_layer_norm_parallel_residual(
  class DropoutAddLayerNorm (line 766) | class DropoutAddLayerNorm(torch.nn.Module):
    method __init__ (line 767) | def __init__(
    method reset_parameters (line 787) | def reset_parameters(self):
    method forward (line 791) | def forward(self, x0, residual=None):
  function rms_norm (line 803) | def rms_norm(x, weight, epsilon):
  class FusedRMSNorm (line 807) | class FusedRMSNorm(torch.nn.Module):
    method __init__ (line 808) | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
    method reset_parameters (line 815) | def reset_parameters(self):
    method forward (line 818) | def forward(self, x):
  class RMSNorm (line 822) | class RMSNorm(torch.nn.Module):
    method __init__ (line 829) | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
    method forward (line 835) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method reset_parameters (line 841) | def reset_parameters(self):

FILE: lit_gpt/speed_monitor.py
  function get_flops_available (line 74) | def get_flops_available(device: torch.device, precision: str) -> Optiona...
  class SpeedMonitorBase (line 122) | class SpeedMonitorBase:
    method __init__ (line 183) | def __init__(
    method on_train_batch_end (line 219) | def on_train_batch_end(
    method eval_end (line 297) | def eval_end(self, eval_elapsed: float):
  class SpeedMonitorFabric (line 301) | class SpeedMonitorFabric(SpeedMonitorBase):
    method __init__ (line 302) | def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
    method on_train_batch_end (line 308) | def on_train_batch_end(self, *args: Any, **kwargs: Any):
  class SpeedMonitorCallback (line 312) | class SpeedMonitorCallback(Callback):
    method __init__ (line 313) | def __init__(self, length_fn: Callable[[Any], int], batch_size: int, *...
    method setup (line 323) | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: s...
    method on_train_start (line 333) | def on_train_start(self, trainer: Trainer, pl_module: LightningModule)...
    method on_train_batch_end (line 340) | def on_train_batch_end(
    method on_validation_start (line 360) | def on_validation_start(self, trainer: Trainer, pl_module: LightningMo...
    method on_validation_end (line 364) | def on_validation_end(self, trainer: Trainer, pl_module: LightningModu...
  function flops_per_param (line 370) | def flops_per_param(config: Config, n_params: int) -> int:
  function estimate_flops (line 379) | def estimate_flops(model: GPT) -> int:
  function measure_flops (line 401) | def measure_flops(model: GPT, x: torch.Tensor) -> int:

FILE: lit_gpt/tokenizer.py
  class Tokenizer (line 8) | class Tokenizer:
    method __init__ (line 9) | def __init__(self, checkpoint_dir: Path) -> None:
    method vocab_size (line 32) | def vocab_size(self) -> int:
    method token_to_id (line 39) | def token_to_id(self, token: str) -> int:
    method encode (line 50) | def encode(
    method decode (line 75) | def decode(self, tensor: torch.Tensor) -> str:

FILE: lit_gpt/utils.py
  function find_multiple (line 20) | def find_multiple(n: int, k: int) -> int:
  function num_parameters (line 27) | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = No...
  function quantization (line 32) | def quantization(mode: Optional[str] = None):
  class NotYetLoadedTensor (line 94) | class NotYetLoadedTensor:
    method __init__ (line 95) | def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
    method rebuild_from_type_v2 (line 102) | def rebuild_from_type_v2(cls, func, new_type, args, state, *, archivei...
    method rebuild_parameter (line 116) | def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, arc...
    method rebuild_tensor_v2 (line 129) | def rebuild_tensor_v2(
    method _load_tensor (line 139) | def _load_tensor(self):
    method __torch_function__ (line 156) | def __torch_function__(cls, func, types, args=(), kwargs=None):
    method __getattr__ (line 163) | def __getattr__(self, name):
    method __repr__ (line 190) | def __repr__(self):
  class LazyLoadingUnpickler (line 194) | class LazyLoadingUnpickler(pickle.Unpickler):
    method __init__ (line 195) | def __init__(self, file, zipfile_context):
    method find_class (line 199) | def find_class(self, module, name):
    method persistent_load (line 209) | def persistent_load(self, pid):
  class lazy_load (line 218) | class lazy_load:
    method __init__ (line 219) | def __init__(self, fn):
    method __enter__ (line 225) | def __enter__(self):
    method __exit__ (line 228) | def __exit__(self, exc_type, exc_val, exc_tb):
  function check_valid_checkpoint_dir (line 233) | def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
  class SavingProxyForStorage (line 267) | class SavingProxyForStorage:
    method __init__ (line 268) | def __init__(self, obj, saver, protocol_version=5):
    method __reduce_ex__ (line 291) | def __reduce_ex__(self, protocol_version):
  class SavingProxyForTensor (line 295) | class SavingProxyForTensor:
    method __init__ (line 296) | def __init__(self, tensor, saver, protocol_version=5):
    method __reduce_ex__ (line 303) | def __reduce_ex__(self, protocol_version):
  class IncrementalPyTorchPickler (line 309) | class IncrementalPyTorchPickler(pickle.Pickler):
    method __init__ (line 310) | def __init__(self, saver, *args, **kwargs):
    method persistent_id (line 317) | def persistent_id(self, obj):
  class incremental_save (line 365) | class incremental_save:
    method __init__ (line 366) | def __init__(self, name):
    method __enter__ (line 372) | def __enter__(self):
    method store_early (line 375) | def store_early(self, tensor):
    method save (line 380) | def save(self, obj):
    method _write_storage_and_return_key (line 391) | def _write_storage_and_return_key(self, storage):
    method __exit__ (line 403) | def __exit__(self, type, value, traceback):
  function step_csv_logger (line 410) | def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any)...
  function chunked_cross_entropy (line 440) | def chunked_cross_entropy(
  function map_old_state_dict_weights (line 482) | def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefi...
  function get_default_supported_precision (line 491) | def get_default_supported_precision(training: bool, tpu: bool = False) -...

FILE: pretrain/tinyllama.py
  function setup (line 78) | def setup(
  function main (line 110) | def main(fabric, train_data_dir, val_data_dir, resume):
  function train (line 165) | def train(fabric, state, train_dataloader, val_dataloader, monitor, resu...
  function validate (line 278) | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: D...
  function create_dataloader (line 301) | def create_dataloader(
  function create_dataloaders (line 339) | def create_dataloaders(
  function get_lr (line 375) | def get_lr(it):

FILE: pretrain/tinyllama_code.py
  function setup (line 77) | def setup(
  function main (line 109) | def main(fabric, train_data_dir, val_data_dir, resume):
  function train (line 169) | def train(fabric, state, train_dataloader, val_dataloader, monitor, resu...
  function validate (line 282) | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: D...
  function create_dataloader (line 305) | def create_dataloader(
  function create_dataloaders (line 343) | def create_dataloaders(
  function get_lr (line 379) | def get_lr(it):

FILE: scripts/convert_hf_checkpoint.py
  function copy_weights_gpt_neox (line 19) | def copy_weights_gpt_neox(
  function copy_weights_falcon (line 62) | def copy_weights_falcon(
  function copy_weights_hf_llama (line 111) | def copy_weights_hf_llama(
  function layer_template (line 173) | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
  function load_param (line 181) | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str...
  function convert_hf_checkpoint (line 193) | def convert_hf_checkpoint(

FILE: scripts/convert_lit_checkpoint.py
  function layer_template (line 20) | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
  function load_param (line 28) | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str...
  function copy_weights_falcon (line 37) | def copy_weights_falcon(
  function copy_weights_gpt_neox (line 85) | def copy_weights_gpt_neox(
  function copy_weights_llama (line 121) | def copy_weights_llama(
  function tensor_split (line 170) | def tensor_split(
  function maybe_unwrap_state_dict (line 215) | def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dic...
  function check_conversion_supported (line 219) | def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> ...
  function get_tinyllama_init_hf_config (line 232) | def get_tinyllama_init_hf_config() -> dict:
  function convert_config_lit_to_hf (line 257) | def convert_config_lit_to_hf(lit_config_dict: dict) -> dict:
  function convert_lit_checkpoint (line 277) | def convert_lit_checkpoint(*,

FILE: scripts/prepare_redpajama.py
  function prepare_sample (line 42) | def prepare_sample(
  function prepare_full (line 86) | def prepare_full(
  function prepare (line 142) | def prepare(

FILE: scripts/prepare_slimpajama.py
  function prepare_full (line 26) | def prepare_full(
  function prepare (line 73) | def prepare(

FILE: scripts/prepare_starcoder.py
  function prepare_full (line 21) | def prepare_full(
  function prepare (line 69) | def prepare(

FILE: sft/finetune.py
  class ModelArguments (line 52) | class ModelArguments:
  class DataArguments (line 63) | class DataArguments:
  class TrainingArguments (line 99) | class TrainingArguments(transformers.Seq2SeqTrainingArguments):
  class GenerationArguments (line 133) | class GenerationArguments:
  function get_accelerate_model (line 168) | def get_accelerate_model(args, checkpoint_dir):
  function print_trainable_parameters (line 212) | def print_trainable_parameters(args, model):
  function smart_tokenizer_and_embedding_resize (line 227) | def smart_tokenizer_and_embedding_resize(
  class DataCollatorForCausalLM (line 252) | class DataCollatorForCausalLM(object):
    method __call__ (line 259) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
  function extract_unnatural_instructions_data (line 304) | def extract_unnatural_instructions_data(examples, extract_reformulations...
  function extract_alpaca_dataset (line 334) | def extract_alpaca_dataset(example):
  function local_dataset (line 341) | def local_dataset(dataset_name):
  function make_data_module (line 354) | def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) ...
  function get_last_checkpoint (line 478) | def get_last_checkpoint(checkpoint_dir):
  function train (line 492) | def train():
Condensed preview: 37 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 144,
    "preview": "__pycache__\n.idea\n.DS_Store\n*.egg-info\nbuild\n.venv\n.vscode\n\n# data\ndata\ncheckpoints\nout\nwandb\n\ntests/original_falcon_40b"
  },
  {
    "path": "EVAL.md",
    "chars": 4818,
    "preview": "## Evaluate TinyLlama\n\n### GPT4All Benchmarks\n\nWe evaluate TinyLlama's commonsense reasoning ability following the [GPT4"
  },
  {
    "path": "LICENSE",
    "chars": 11344,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "PRETRAIN.md",
    "chars": 2973,
    "preview": "## Pretrain TinyLlama\n\n### Installation\nWe expect you have CUDA 11.8 installed.\n#### Install Pytorch Nightly.\n```bash\npi"
  },
  {
    "path": "README.md",
    "chars": 14101,
    "preview": "<div align=\"center\">\n\n# TinyLlama-1.1B\nEnglish | [中文](README_zh-CN.md)\n\n[Chat Demo](https://huggingface.co/spaces/TinyLl"
  },
  {
    "path": "README_zh-CN.md",
    "chars": 10455,
    "preview": "<div align=\"center\">\n\n# TinyLlama-1.1B\n[English](README.md) | 中文\n\n[Chat Demo](https://huggingface.co/spaces/TinyLlama/ti"
  },
  {
    "path": "chat_gradio/README.md",
    "chars": 1060,
    "preview": "## Tinyllama Chatbot Implementation with Gradio\n\nWe offer an easy way to interact with Tinyllama. This guide explains ho"
  },
  {
    "path": "chat_gradio/app.py",
    "chars": 2537,
    "preview": "import gradio as gr\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom transformers import S"
  },
  {
    "path": "chat_gradio/requirements.txt",
    "chars": 47,
    "preview": "torch>=2.0\ntransformers>=4.35.0\ngradio>=4.13.0\n"
  },
  {
    "path": "lit_gpt/__init__.py",
    "chars": 808,
    "preview": "from lit_gpt.model import GPT\nfrom lit_gpt.config import Config\nfrom lit_gpt.tokenizer import Tokenizer\nfrom lit_gpt.fus"
  },
  {
    "path": "lit_gpt/adapter.py",
    "chars": 12122,
    "preview": "\"\"\"Implementation of the paper:\n\nLLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention\nhttps:"
  },
  {
    "path": "lit_gpt/adapter_v2.py",
    "chars": 12807,
    "preview": "\"\"\"Implementation of the paper:\n\nLLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model\nhttps://arxiv.org/abs/23"
  },
  {
    "path": "lit_gpt/config.py",
    "chars": 20125,
    "preview": "from dataclasses import dataclass\nfrom typing import Any, Literal, Optional, Type\n\nimport torch\nfrom typing_extensions i"
  },
  {
    "path": "lit_gpt/fused_cross_entropy.py",
    "chars": 6360,
    "preview": "# Copyright (c) 2023, Tri Dao.\n\nimport torch\nimport torch.nn as nn\nimport xentropy_cuda_lib\n\n# `all_gather_into_tensor` "
  },
  {
    "path": "lit_gpt/fused_rotary_embedding.py",
    "chars": 3021,
    "preview": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nfrom typing import Optional, Tuple\n\nimport rotary_emb\nimport torch\nfrom eino"
  },
  {
    "path": "lit_gpt/lora.py",
    "chars": 31609,
    "preview": "# Derived from https://github.com/microsoft/LoRA\n#  --------------------------------------------------------------------"
  },
  {
    "path": "lit_gpt/model.py",
    "chars": 15495,
    "preview": "\"\"\"Full definition of a GPT NeoX Language Model, all of it in this single file.\n\nBased on the nanoGPT implementation: ht"
  },
  {
    "path": "lit_gpt/packed_dataset.py",
    "chars": 7664,
    "preview": "# Very loosely inspired by indexed_dataset in Fairseq, Megatron\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatr"
  },
  {
    "path": "lit_gpt/rmsnorm.py",
    "chars": 23995,
    "preview": "import torch\n# Copyright (c) 2022, Tri Dao.\n# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer"
  },
  {
    "path": "lit_gpt/speed_monitor.py",
    "chars": 20777,
    "preview": "import time\nfrom collections import deque\nfrom contextlib import nullcontext\nfrom typing import Any, Callable, Deque, Di"
  },
  {
    "path": "lit_gpt/tokenizer.py",
    "chars": 2913,
    "preview": "import json\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\n\n\nclass Tokenizer:\n    def __init__(self,"
  },
  {
    "path": "lit_gpt/utils.py",
    "chars": 19417,
    "preview": "\"\"\"Utility functions for training and inference.\"\"\"\n\nimport pickle\nimport sys\nimport warnings\nfrom contextlib import con"
  },
  {
    "path": "pretrain/tinyllama.py",
    "chars": 14901,
    "preview": "import glob\nimport math\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Union\nimport"
  },
  {
    "path": "pretrain/tinyllama_code.py",
    "chars": 15139,
    "preview": "import glob\nimport math\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Union\nimport"
  },
  {
    "path": "requirements.txt",
    "chars": 711,
    "preview": "torch>=2.1.0dev\nlightning==2.1.2\nlightning[app]\njsonargparse[signatures]  # CLI\npandas\npyarrow\ntokenizers\nsentencepiece\n"
  },
  {
    "path": "script.sh",
    "chars": 838,
    "preview": "python scripts/convert_hf_checkpoint.py --checkpoint_dir  out/TinyLlama-1.1B-900B --model_name tiny_LLaMA_1b\n\npython tes"
  },
  {
    "path": "scripts/convert_hf_checkpoint.py",
    "chars": 10503,
    "preview": "import contextlib\nimport gc\nimport json\nimport sys\nfrom functools import partial\nfrom pathlib import Path\nfrom typing im"
  },
  {
    "path": "scripts/convert_lit_checkpoint.py",
    "chars": 12612,
    "preview": "import contextlib\nimport gc\nimport sys\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Dict, L"
  },
  {
    "path": "scripts/prepare_redpajama.py",
    "chars": 5457,
    "preview": "import glob\nimport json\nimport os\nimport sys\nfrom pathlib import Path\n\nimport numpy as np\nfrom tqdm import tqdm\n\n# suppo"
  },
  {
    "path": "scripts/prepare_slimpajama.py",
    "chars": 3297,
    "preview": "import json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tq"
  },
  {
    "path": "scripts/prepare_starcoder.py",
    "chars": 3297,
    "preview": "import json\nimport glob\nimport os\nfrom pathlib import Path\nimport sys\nfrom typing import List\nimport numpy as np\nfrom tq"
  },
  {
    "path": "sft/finetune.py",
    "chars": 23990,
    "preview": "# This source code is licensed under the MIT license found in the\n# LICENSE file in the root directory of this source tr"
  },
  {
    "path": "sft/script.sh",
    "chars": 2703,
    "preview": "# We include a simple full-parameter finetuning & inference script here. Our V0.1 chat model is finetuned using this scr"
  },
  {
    "path": "sft/simple_inference.py",
    "chars": 657,
    "preview": "from transformers import AutoTokenizer\nimport transformers \nimport torch\nmodel = \"PY007/TinyLlama-1.1B-Chat-v0.1\"\ntokeni"
  },
  {
    "path": "sft/simple_inference2.py",
    "chars": 679,
    "preview": "\n\nfrom transformers import AutoTokenizer\nimport transformers \nimport torch\nmodel = \"PY007/TinyLlama-1.1B-Chat-v0.2\"\ntoke"
  },
  {
    "path": "speculative_decoding/README.md",
    "chars": 4670,
    "preview": "## Speculative Decoding\n\n### HuggingFace \"Assisted Generation\"\n\n\n| Large Model | Native Decoding | Assisted Decoding  |\n"
  },
  {
    "path": "speculative_decoding/instruct_hf_assisted_decoding.py",
    "chars": 1386,
    "preview": "from transformers import AutoModelForCausalLM, AutoTokenizer\nimport torch\nimport time\n\n\nmodel_id = \"huggyllama/llama-13b"
  }
]
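
The preview list above is plain JSON, so it can be read with the standard library. The following is a minimal sketch, assuming each entry carries the `path`, `chars`, and `preview` keys seen above. The `summarize` helper and the two-entry excerpt are illustrative only and are not part of the repository; the second preview string is shortened for brevity.

```python
import json

# Hypothetical two-entry excerpt in the same shape as the preview list above.
# Paths and character counts are copied from the listing; the second preview
# string is shortened here.
preview_json = """
[
  {"path": "chat_gradio/requirements.txt", "chars": 47,
   "preview": "torch>=2.0\\ntransformers>=4.35.0\\ngradio>=4.13.0\\n"},
  {"path": "sft/simple_inference.py", "chars": 657,
   "preview": "from transformers import AutoTokenizer"}
]
"""


def summarize(entries):
    """Return (file count, total characters, path of the largest file)."""
    total = sum(entry["chars"] for entry in entries)
    largest = max(entries, key=lambda entry: entry["chars"])
    return len(entries), total, largest["path"]


entries = json.loads(preview_json)
count, total, largest = summarize(entries)
print(f"{count} files, {total} chars, largest: {largest}")
```

For small files the snippet holds the whole file, so `len(entry["preview"])` equals `chars`; for larger files the snippet is cut off and only `chars` reflects the real size.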
