Repository: jzhang38/TinyLlama
Branch: main
Commit: bf122247c486
Files: 37
Total size: 317.8 KB

Directory structure:
gitextract_67kobddf/
├── .gitignore
├── EVAL.md
├── LICENSE
├── PRETRAIN.md
├── README.md
├── README_zh-CN.md
├── chat_gradio/
│   ├── README.md
│   ├── app.py
│   └── requirements.txt
├── lit_gpt/
│   ├── __init__.py
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── config.py
│   ├── fused_cross_entropy.py
│   ├── fused_rotary_embedding.py
│   ├── lora.py
│   ├── model.py
│   ├── packed_dataset.py
│   ├── rmsnorm.py
│   ├── speed_monitor.py
│   ├── tokenizer.py
│   └── utils.py
├── pretrain/
│   ├── tinyllama.py
│   └── tinyllama_code.py
├── requirements.txt
├── script.sh
├── scripts/
│   ├── convert_hf_checkpoint.py
│   ├── convert_lit_checkpoint.py
│   ├── prepare_redpajama.py
│   ├── prepare_slimpajama.py
│   └── prepare_starcoder.py
├── sft/
│   ├── finetune.py
│   ├── script.sh
│   ├── simple_inference.py
│   └── simple_inference2.py
└── speculative_decoding/
    ├── README.md
    └── instruct_hf_assisted_decoding.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
.idea
.DS_Store
*.egg-info
build
.venv
.vscode

# data
data
checkpoints
out
wandb
tests/original_falcon_40b.py

sft/output
sft/wandb

================================================
FILE: EVAL.md
================================================
## Evaluate TinyLlama

### GPT4All Benchmarks

We evaluate TinyLlama's commonsense reasoning ability following the [GPT4All](https://gpt4all.io/index.html) evaluation suite. We include Pythia as our baseline. We report the acc_norm by default.

Base models:

| Model | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg |
|-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----|
| Pythia-1.0B | 300B | 47.16 | 31.40 | 53.43 | 27.05 | 48.99 | 60.83 | 69.21 | 48.30 |
| TinyLlama-1.1B-intermediate-step-50K-104b | 103B | 43.50 | 29.80 | 53.28 | 24.32 | 44.91 | 59.66 | 67.30 | 46.11 |
| TinyLlama-1.1B-intermediate-step-240k-503b | 503B | 49.56 | 31.40 | 55.80 | 26.54 | 48.32 | 56.91 | 69.42 | 48.28 |
| TinyLlama-1.1B-intermediate-step-480k-1007B | 1007B | 52.54 | 33.40 | 55.96 | 27.82 | 52.36 | 59.54 | 69.91 | 50.22 |
| TinyLlama-1.1B-intermediate-step-715k-1.5T | 1.5T | 53.68 | 35.20 | 58.33 | 29.18 | 51.89 | 59.08 | 71.65 | 51.29 |
| TinyLlama-1.1B-intermediate-step-955k-2T | 2T | 54.63 | 33.40 | 56.83 | 28.07 | 54.67 | 63.21 | 70.67 | 51.64 |
| TinyLlama-1.1B-intermediate-step-1195k-2.5T | 2.5T | 58.96 | 34.40 | 58.72 | 31.91 | 56.78 | 63.21 | 73.07 | 53.86 |
| TinyLlama-1.1B-intermediate-step-1431k-3T | 3T | 59.20 | 36.00 | 59.12 | 30.12 | 55.25 | 57.83 | 73.29 | 52.99 |

Chat models:

| Model | Pretrain Tokens | HellaSwag | Obqa | WinoGrande | ARC_c | ARC_e | boolq | piqa | avg |
|-------------------------------------------|-----------------|-----------|------|------------|-------|-------|-------|------|-----|
| [TinyLlama-1.1B-Chat-v0.1](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.1) | 503B | 53.81 | 32.20 | 55.01 | 28.67 | 49.62 | 58.04 | 69.64 | 49.57 |
| [TinyLlama-1.1B-Chat-v0.2](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.2) | 503B | 53.63 | 32.80 | 54.85 | 28.75 | 49.16 | 55.72 | 69.48 | 49.20 |
| [TinyLlama-1.1B-Chat-v0.3](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3) | 1T | 56.81 | 34.20 | 55.80 | 30.03 | 53.20 | 59.57 | 69.91 | 51.36 |
| [TinyLlama-1.1B-Chat-v0.4](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4) | 1.5T | 58.59 | 35.40 | 58.80 | 30.80 | 54.04 | 57.31 | 71.16 | 52.30 |

We observed large improvements once we finetuned the model. We attribute this to two factors:
1. The base model has not undergone learning-rate cool-down, and finetuning helps to cool down the learning rate.
2. The SFT stage better elicits the model's internal knowledge.

You can obtain the above scores by running [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness):
```bash
python main.py \
    --model hf-causal \
    --model_args pretrained=PY007/TinyLlama-1.1B-Chat-v0.1,dtype="float" \
    --tasks hellaswag,openbookqa,winogrande,arc_easy,arc_challenge,boolq,piqa \
    --device cuda:0 --batch_size 32
```

### Instruct-Eval Benchmarks

We evaluate TinyLlama's problem-solving ability on the [Instruct-Eval](https://github.com/declare-lab/instruct-eval) evaluation suite.

| Model | MMLU | BBH | HumanEval | DROP |
| ------------------------------------------------- | ----- | ----- | --------- | ----- |
| Pythia-1.0B | 25.70 | 28.19 | 1.83 | 4.25 |
| TinyLlama-1.1B-intermediate-step-50K-104b | 26.45 | 28.82 | 5.49 | 11.42 |
| TinyLlama-1.1B-intermediate-step-240k-503b | 26.16 | 28.83 | 4.88 | 12.43 |
| TinyLlama-1.1B-intermediate-step-480K-1T | 24.65 | 29.21 | 6.1 | 13.03 |
| TinyLlama-1.1B-intermediate-step-715k-1.5T | 24.85 | 28.2 | 7.93 | 14.43 |
| TinyLlama-1.1B-intermediate-step-955k-2T | 25.97 | 29.07 | 6.71 | 13.14 |
| TinyLlama-1.1B-intermediate-step-1195k-token-2.5T | 25.92 | 29.32 | 9.15 | 15.45 |

You can obtain the above scores by running [instruct-eval](https://github.com/declare-lab/instruct-eval):
```bash
CUDA_VISIBLE_DEVICES=0 python main.py mmlu --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=1 python main.py bbh --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=2 python main.py drop --model_name llama --model_path PY007/TinyLlama-1.1B-intermediate-step-480K-1T
CUDA_VISIBLE_DEVICES=3 python main.py humaneval --model_name llama --n_sample 1 --model_path
PY007/TinyLlama-1.1B-intermediate-step-480K-1T ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [2023] Lightning AI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

================================================
FILE: PRETRAIN.md
================================================
## Pretrain TinyLlama

### Installation
We expect you have CUDA 11.8 installed.
#### Install PyTorch Nightly
```bash
pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev'
```
#### Build XFormers from Source
Note: as of 2023/09/02, xformers does not provide pre-built binaries for torch 2.1. You have to build it from source.
```bash
pip uninstall ninja -y && pip install ninja -U
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```

#### Install Flash-Attention 2 and other fused operators:
```bash
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python setup.py install
cd csrc/rotary && pip install .
cd ../layer_norm && pip install .
cd ../xentropy && pip install .
cd ../.. && rm -rf flash-attention
```
#### Install Remaining Dependencies
```
pip install -r requirements.txt tokenizers sentencepiece
```
This installs the remaining dependencies.
Building xformers and flash-attention may take 5 minutes or more. Do not worry if the process seems to stall or the terminal prints many warnings.

Then you are ready to go 🎉!

### Data Preparation

#### Download Datasets
Download the SlimPajama and Starcoderdata datasets to your chosen directory.
```bash
cd /path/to/dataset
git lfs install
git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B
git clone https://huggingface.co/datasets/bigcode/starcoderdata
```
The SlimPajama dataset takes 893GB of disk space and starcoderdata takes 290GB.

#### Tokenize data
Use the provided scripts to tokenize the datasets and divide them into chunks.
```bash
python scripts/prepare_starcoder.py --source_path /path/to/starcoderdata/ --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split validation --percentage 1.0
python scripts/prepare_slimpajama.py --source_path /path/to/SlimPajama --tokenizer_path data/llama --destination_path data/slim_star_combined --split train --percentage 1.0
```
The processed data takes 1.8TB of storage.
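
To make "divide them into chunks" concrete, here is a minimal sketch of the packing idea: token sequences are concatenated into one stream and cut into fixed-size chunks, with the last chunk padded. It is an illustration only, not the repository's implementation. The function name `pack_into_chunks` and the `pad_id` parameter are hypothetical, and the sketch returns Python lists, whereas the actual scripts write binary chunk files to `--destination_path`.

```python
def pack_into_chunks(token_sequences, chunk_size, pad_id=0):
    """Concatenate token sequences and cut the stream into fixed-size chunks.

    Hypothetical helper for illustration; pad_id is an assumed padding token.
    """
    chunks = []
    buffer = [pad_id] * chunk_size  # current, partially filled chunk
    idx = 0  # next free slot in buffer
    for tokens in token_sequences:
        tokens = list(tokens)
        # Spill into new chunks while the sequence does not fit in the buffer.
        while idx + len(tokens) > chunk_size:
            part_len = chunk_size - idx
            buffer[idx:] = tokens[:part_len]
            chunks.append(buffer)
            buffer = [pad_id] * chunk_size
            idx = 0
            tokens = tokens[part_len:]
        buffer[idx : idx + len(tokens)] = tokens
        idx += len(tokens)
    if idx > 0:  # flush the final chunk, which keeps its padding
        chunks.append(buffer)
    return chunks
```

For example, packing `[[1, 2, 3], [4, 5], [6, 7, 8, 9]]` with `chunk_size=4` yields `[[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 0, 0]]`: sequence boundaries are not preserved, which is why every stored chunk has the same length.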

### Pretraining
If your setup comprises two nodes, each with 8 GPUs, you can initiate pretraining with the following commands:

On node 1:
```
lightning run model \
    --node-rank=0 \
    --main-address=172.16.101.5 \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star
```
On node 2:
```
lightning run model \
    --node-rank=1 \
    --main-address=172.16.101.5 \
    --accelerator=cuda \
    --devices=8 \
    --num-nodes=2 \
    pretrain/tinyllama.py --devices 8 --train_data_dir data/slim_star --val_data_dir data/slim_star
```
You can follow [these instructions](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html) if you have a Slurm cluster.

================================================
FILE: README.md
================================================
Above is the training loss curve taken from the Llama 2 paper, which states: "We observe that after pretraining on 2T Tokens, the models still did not show any sign of saturation". That is why we believe pretraining a 1.1B model on 3T tokens is a reasonable thing to do. Even if the loss eventually stops decreasing, we can still study the phenomenon of saturation and learn something from it.
#### 2. What does "saturation" mean?
The figure from the Pythia paper displays the LAMBADA accuracy plotted against the total training tokens (300B). The term "saturation" pertains specifically to the 70M and 160M models. Notably, even the 410M model does not saturate with 300B tokens, as it continues to show an increasing trend, similar to the trend of larger models.
## Star History
[Star History Chart](https://star-history.com/#jzhang38/TinyLlama&Date)
================================================
FILE: README_zh-CN.md
================================================
self._chunk_size:
            part_len = self._chunk_size - self._idx
            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
            self._write_chunk()
            arr = arr[part_len:]

        arr_len = arr.shape[0]
        self._arr[self._idx : self._idx + arr_len] = arr
        self._idx += arr_len

    def write_reminder(self):
        self._write_chunk()


class PackedDatasetIterator:
    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._block_idxs = None

        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        #       (or text file) with the sequence of all files to be
        #       fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        self._load_n_chunks()

    def _read_header(self, path):
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            version = struct.unpack("<Q", f.read(8))
            assert version == (1,)
            (dtype_code,) = struct.unpack("<B", f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtype, chunk_size

    def _close_mmaps(self):
        for mmap in self._mmaps:
            mmap._mmap.close()

    def _load_n_chunks(self):
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        if self._n_chunks > len(self._filenames[self._file_idx :]):
            # if not self._wrap:
            #     raise StopIteration
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)

        self._curr_idx = 0

    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self

    def __next__(self):
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        return torch.from_numpy(arr.astype(np.int64))


class CombinedDataset(IterableDataset):
    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            self._weights = [1 / n_datasets] * n_datasets

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)


class CombinedDatasetIterator:
    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)
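
# Illustrative sketch, not part of the original module: the index arithmetic
# used in PackedDatasetIterator.__next__ above, isolated as a pure function.
# The name locate_block and its parameter names are hypothetical; the
# arithmetic mirrors the chunk_id / elem_id / offset computation shown above,
# assuming every chunk holds the same number of equally sized blocks.

```python
def locate_block(block_idx, n_blocks_per_chunk, block_size, itemsize):
    """Map a global block index to (chunk id, byte offset inside that chunk)."""
    # Which of the currently loaded chunks contains the block.
    chunk_id = block_idx // n_blocks_per_chunk
    # Index of the block's first token within that chunk.
    elem_id = (block_idx % n_blocks_per_chunk) * block_size
    # Convert the token index to a byte offset for the chunk's dtype.
    return chunk_id, elem_id * itemsize
```

# With 4 blocks per chunk, 2048 tokens per block and 2-byte tokens, block 5 is
# the second block of chunk 1, so locate_block(5, 4, 2048, 2) gives (1, 4096).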
================================================ FILE: lit_gpt/rmsnorm.py ================================================ import torch # Copyright (c) 2022, Tri Dao. # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16 import dropout_layer_norm import torch from torch.nn import init def maybe_align(x, alignment_in_bytes=16): """Assume that x already has last dim divisible by alignment_in_bytes""" # TD [2023-07-04] I'm not 100% sure that clone will align the memory # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() def _dropout_add_layer_norm_forward( x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes""" hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) residualmat = residual.view((-1, hidden_size)) if residual is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon, 1.0, 0, None, residual_in_fp32, is_rms_norm, ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma def _dropout_add_layer_norm_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale. 
""" hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(xmat.shape) dxmat = dx.view(xmat.shape) if dx is not None else None x0mat = x0.view((-1, hidden_size)) if x0 is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None if colscale is not None: assert x0 is not None, "x0 is required to compute the gradient of colscale" dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None, dropout_p, 1.0, 0, has_residual, is_rms_norm, ) # dresidualmat is None if not has_residual if colscale is None: return dx0mat, dresidualmat, dgamma, dbeta else: dcolscale = rest[0] return dx0mat, dresidualmat, dgamma, dbeta, dcolscale def _dropout_add_layer_norm_subset_forward( x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes""" hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) residualmat = residual.view((-1, hidden_size)) if residual is not None else None x0_subset = x0_subset.view(-1) if x0_subset is not None else None out_subset = out_subset.view(-1) if out_subset is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm, ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma def _dropout_add_layer_norm_subset_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm=False, ): """Assume that arguments are contiguous and 
aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale. """ hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(-1, hidden_size) dxmat = dx.view(xmat.shape) if dx is not None else None x0mat = x0.view((-1, hidden_size)) if x0 is not None else None x0_subset = x0_subset.view(-1) if x0_subset is not None else None out_subset = out_subset.view(-1) if out_subset is not None else None if colscale is not None: assert x0 is not None, "x0 is required to compute the gradient of colscale" dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm, ) # dresidualmat is None if not has_residual if colscale is None: return dx0mat, dresidualmat, dgamma, dbeta else: dcolscale = rest[0] return dx0mat, dresidualmat, dgamma, dbeta, dcolscale def _dropout_add_layer_norm_parallel_residual_forward( x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes""" hidden_size = gamma0.numel() x0mat = x0.view((-1, hidden_size)) x1mat = x1.view((-1, hidden_size)) if x1 is not None else None residualmat = residual.view((-1, hidden_size)) if residual is not None else None ( z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma, ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, None, residual_in_fp32, is_rms_norm, ) # dmask0 and dmask1 are None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma def 
_dropout_add_layer_norm_parallel_residual_backward( dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). """ hidden_size = gamma0.numel() xmat = x.view((-1, hidden_size)) dz0mat = dz0.view(xmat.shape) dz1mat = dz1.view(xmat.shape) if dz1 is not None else None dxmat = dx.view(xmat.shape) if dx is not None else None ( dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest, ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm, ) # dresidualmat is None if not has_residual return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 class DropoutAddLayerNormFn(torch.autograd.Function): @staticmethod def forward( ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False, ): x0 = maybe_align(x0.contiguous(), 16) residual = maybe_align(residual.contiguous(), 16) if residual is not None else None gamma = maybe_align(gamma.contiguous(), 16) beta = maybe_align(beta.contiguous(), 16) if beta is not None else None rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32, is_rms_norm, ) # Only need to save x0 if we need to compute gradient wrt colscale x0_saved = x0 if colscale is not None else None ctx.save_for_backward( xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale ) ctx.prenorm = prenorm ctx.dropout_p = dropout_p 
ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta is not None if not return_dmask: return ( zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) ) else: dmask = ( dmask.view(x0.shape) if dropout_p > 0.0 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) ctx.mark_non_differentiable(dmask) return ( (zmat.view(x0.shape), dmask) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) ) @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() dz = maybe_align(dz.contiguous(), 16) # this happens! dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors # x0 is None if colscale is None dropout_p = ctx.dropout_p has_residual = ctx.has_residual dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, ctx.is_rms_norm, ) dx0 = dx0mat.view(x.shape) dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None dcolscale = rest[0] if colscale is not None else None return ( dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None, None, None, None, None, ) class DropoutAddLayerNormSubsetFn(torch.autograd.Function): @staticmethod def forward( ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False, ): x0 = maybe_align(x0.contiguous(), 16) residual = maybe_align(residual.contiguous(), 16) if residual is not None else None gamma = maybe_align(gamma.contiguous(), 16) beta = maybe_align(beta.contiguous(), 16) if beta is not None else None colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( x0, residual, gamma, 
beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, is_rms_norm, ) # Only need to save x0 if we need to compute gradient wrt colscale x0_saved = x0 if colscale is not None else None x_shape = (-1, *x0.shape[1:]) ctx.save_for_backward( xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset ) ctx.prenorm = prenorm ctx.dropout_p = dropout_p ctx.rowscale_const = rowscale_const ctx.x0_numrows = x0.shape[:-1].numel() ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta is not None z_shape = (-1, *x0.shape[1:]) if not return_dmask: return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) else: z = zmat.view(z_shape) dmask = ( dmask.view(x0.shape) if dropout_p > 0.0 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) ctx.mark_non_differentiable(dmask) return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() dz = maybe_align(dz.contiguous(), 16) # this happens! 
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors # x0 is None if colscale is None dropout_p = ctx.dropout_p has_residual = ctx.has_residual dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm, ) dx0 = dx0mat.view(-1, *x.shape[1:]) dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None dcolscale = rest[0] if colscale is not None else None return ( dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None, None, None, None, None, None, None, None, ) class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): @staticmethod def forward( ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False, ): x0 = maybe_align(x0.contiguous(), 16) x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None residual = maybe_align(residual.contiguous(), 16) if residual is not None else None gamma0 = maybe_align(gamma0.contiguous(), 16) beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None ( z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma, ) = _dropout_add_layer_norm_parallel_residual_forward( x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32, is_rms_norm, ) ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) ctx.prenorm = prenorm ctx.dropout_p = dropout_p ctx.has_x1 = x1 is not None ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta0 is not None z = 
(z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) if not return_dmask: return z if not prenorm else (*z, xmat.view(x0.shape)) else: dmask0 = ( dmask0.view(x0.shape) if dropout_p > 0.0 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) dmask1 = ( dmask1.view(x0.shape) if dropout_p > 0.0 and x1 is not None else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) ctx.mark_non_differentiable(dmask0) ctx.mark_non_differentiable(dmask1) return ( (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) ) @staticmethod def backward(ctx, dz0, dz1, *args): dz0 = maybe_align(dz0.contiguous(), 16) # this happens! dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors dropout_p = ctx.dropout_p has_x1 = ctx.has_x1 has_residual = ctx.has_residual ( dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, ) = _dropout_add_layer_norm_parallel_residual_backward( dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, ctx.is_rms_norm, ) dx0 = dx0mat.view(x.shape) dx1 = dx1mat.view(x.shape) if dx1mat is not None else None dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None return ( dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1, dbeta1 if ctx.has_beta else None, None, None, None, None, None, None, ) def layer_norm(x, weight, bias, epsilon): return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) def dropout_add_layer_norm( x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. 
""" return DropoutAddLayerNormFn.apply( x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, False, return_dropout_mask, ) def dropout_add_layer_norm_subset( x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, x0_subset=None, out_subset=None, rowscale_const=1.0, out_numrows=0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. """ return DropoutAddLayerNormSubsetFn.apply( x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask, ) def dropout_add_layer_norm_parallel_residual( x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. 
""" return DropoutAddLayerNormParallelResidualFn.apply( x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm, False, return_dropout_mask, ) class DropoutAddLayerNorm(torch.nn.Module): def __init__( self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.prenorm = prenorm self.p = p self.eps = eps self.residual_in_fp32 = residual_in_fp32 self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.reset_parameters() def reset_parameters(self): init.ones_(self.weight) init.zeros_(self.bias) def forward(self, x0, residual=None): return dropout_add_layer_norm( x0, residual, self.weight, self.bias, self.p if self.training else 0.0, self.eps, prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32, ) def rms_norm(x, weight, epsilon): return DropoutAddLayerNormFn.apply( x, None, weight, None, None, None, 0.0, epsilon, False, False, True ) class FusedRMSNorm(torch.nn.Module): def __init__(self, size: int, dim: int = -1, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.ones(size)) self.dim = dim self.reset_parameters() def reset_parameters(self): init.ones_(self.weight) def forward(self, x): return rms_norm(x, self.weight, self.eps) class RMSNorm(torch.nn.Module): """Root Mean Square Layer Normalization. Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 
""" def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: super().__init__() self.weight = torch.nn.Parameter(torch.ones(size)) self.eps = eps self.dim = dim def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE: the original RMSNorm paper implementation is not equivalent norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) x_normed = x * torch.rsqrt(norm_x + self.eps) return self.weight * x_normed def reset_parameters(self): torch.nn.init.ones_(self.weight) ================================================ FILE: lit_gpt/speed_monitor.py ================================================ import time from collections import deque from contextlib import nullcontext from typing import Any, Callable, Deque, Dict, Optional import torch from lightning import Callback, Fabric, LightningModule, Trainer from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only from torch.utils.flop_counter import FlopCounterMode import math from lit_gpt import GPT, Config from lit_gpt.utils import num_parameters GPU_AVAILABLE_FLOPS = { # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet # nvidia publishes spec sheet with a 2x sparsity factor "h100-sxm": { "64-true": 67e12, "32-true": 67e12, "16-true": 1.979e15 / 2, "16-mixed": 1.979e15 / 2, "bf16-true": 1.979e15 / 2, "bf16-mixed": 1.979e15 / 2, "8-true": 3.958e15 / 2, "8-mixed": 3.958e15 / 2, }, "h100-pcie": { "64-true": 51e12, "32-true": 51e12, "16-true": 1.513e15 / 2, "16-mixed": 1.513e15 / 2, "bf16-true": 1.513e15 / 2, "bf16-mixed": 1.513e15 / 2, "8-true": 3.026e15 / 2, "8-mixed": 3.026e15 / 2, }, # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf # sxm and pcie have same flop counts "a100": { "64-true": 19.5e12, "32-true": 19.5e12, "16-true": 312e12, "16-mixed": 312e12, 
"bf16-true": 312e12, "bf16-mixed": 312e12, }, # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf "a10g": {"32-true": 31.2e12, "16-true": 125e12, "16-mixed": 125e12, "bf16-true": 125e12, "bf16-mixed": 125e12}, # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf "v100-sxm": {"64-true": 7.8e12, "32-true": 15.7e12, "16-true": 125e12, "16-mixed": 125e12}, "v100-pcie": {"64-true": 7e12, "32-true": 14e12, "16-true": 112e12, "16-mixed": 112e12}, "v100s-pcie": {"64-true": 8.2e12, "32-true": 16.4e12, "16-true": 130e12, "16-mixed": 130e12}, # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf # sxm and pcie have same flop counts "t4": {"32-true": 8.1e12, "16-true": 65e12, "16-mixed": 65e12, "8-true": 130e12, "int4": 260e12}, # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf "quadro rtx 5000": {"32-true": 11.2e12, "16-true": 89.2e12, "16-mixed": 89.2e12}, } TPU_AVAILABLE_FLOPS = { # flop count for each TPU generation is the same for all precisions # since bfloat16 precision is always used for performing matrix operations # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16 # source: https://arxiv.org/pdf/1907.10701.pdf "v2": 45e12, # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3 "v3": 123e12, # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4 "v4": 275e12, } def get_flops_available(device: torch.device, precision: str) -> Optional[float]: if device.type == "cuda": device_name = torch.cuda.get_device_name(device).lower() if "h100" in device_name and "hbm3" in device_name: device_name = "h100-sxm" elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name): device_name = "h100-pcie" elif "a100" in 
device_name: device_name = "a100" elif "a10g" in device_name: device_name = "a10g" elif "v100-sxm" in device_name: device_name = "v100-sxm" elif "v100-pcie" in device_name: device_name = "v100-pcie" elif "t4" in device_name: device_name = "t4" elif "quadro rtx 5000" in device_name: device_name = "quadro rtx 5000" else: device_name = None if device_name is not None: try: return int(GPU_AVAILABLE_FLOPS[device_name][precision]) except KeyError: raise KeyError( f"flop count not found for {device_name} with precision: {precision}; " "MFU cannot be calculated and reported." ) elif device.type == "xla": from torch_xla.experimental import tpu device_name = tpu.get_tpu_env()["TYPE"].lower() try: return int(TPU_AVAILABLE_FLOPS[device_name]) except KeyError: raise KeyError( f"flop count not found for {device_name} with precision: {precision}; " "MFU cannot be calculated and reported." ) return None # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py class SpeedMonitorBase: """Logs the training throughput and utilization. +-------------------------------------+-----------------------------------------------------------+ | Key | Logged data | +=====================================+===========================================================+ | | Rolling average (over `window_size` most recent | | `throughput/batches_per_sec` | batches) of the number of batches processed per second | | | | +-------------------------------------+-----------------------------------------------------------+ | | Rolling average (over `window_size` most recent | | `throughput/samples_per_sec` | batches) of the number of samples processed per second | | | | +-------------------------------------+-----------------------------------------------------------+ | | Rolling average (over `window_size` most recent | | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. 
| | | This may include padding depending on dataset | +-------------------------------------+-----------------------------------------------------------+ | | Estimates flops by `flops_per_batch * batches_per_sec` | | `throughput/flops_per_sec` | | | | | +-------------------------------------+-----------------------------------------------------------+ | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size | +-------------------------------------+-----------------------------------------------------------+ | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size | +-------------------------------------+-----------------------------------------------------------+ | | `throughput/tokens_per_sec` divided by world size. This | | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset | | | | +-------------------------------------+-----------------------------------------------------------+ | | `throughput/flops_per_sec` divided by world size. Only | | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` | | | | +-------------------------------------+-----------------------------------------------------------+ | | `throughput/device/flops_per_sec` divided by `flops_available`. 
| | `throughput/device/mfu` | | | | | +-------------------------------------+-----------------------------------------------------------+ | `time/train` | Total elapsed training time | +-------------------------------------+-----------------------------------------------------------+ | `time/val` | Total elapsed validation time | +-------------------------------------+-----------------------------------------------------------+ | `time/total` | Total elapsed time (time/train + time/val) | +-------------------------------------+-----------------------------------------------------------+ Notes: - The implementation assumes that devices are homogeneous as it normalizes by the world size. - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or batches/sec to measure throughput under this circumstance. - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``. There is no widespread, realistic, and reliable implementation to compute them. We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which will almost always be an overestimate when compared to the true value. Args: window_size (int, optional): Number of batches to use for a rolling average of throughput. Defaults to 100. time_unit (str, optional): Time unit to use for `time` logging. Can be one of 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'. 
""" def __init__( self, flops_available: float, log_dict: Callable[[Dict, int], None], window_size: int = 100, time_unit: str = "hours", log_iter_interval: int = 1, ): self.flops_available = flops_available self.log_dict = log_dict self.log_iter_interval = log_iter_interval # Track the batch num samples and wct to compute throughput over a window of batches self.history_samples: Deque[int] = deque(maxlen=window_size + 1) self.history_training_loss: Deque[int] = deque(maxlen=log_iter_interval) self.history_wct: Deque[float] = deque(maxlen=window_size + 1) self.history_lengths: Deque[int] = deque(maxlen=window_size + 1) self.history_flops: Deque[int] = deque(maxlen=window_size + 1) self.divider = 1 if time_unit == "seconds": self.divider = 1 elif time_unit == "minutes": self.divider = 60 elif time_unit == "hours": self.divider = 60 * 60 elif time_unit == "days": self.divider = 60 * 60 * 24 else: raise ValueError( f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".' 
) # Keep track of time spent evaluating self.total_eval_wct = 0.0 self.iter = -1 def on_train_batch_end( self, samples: int, # total samples seen (per device) train_elapsed: float, # total training time (seconds) world_size: int, step_count: int, flops_per_batch: Optional[int] = None, # (per device) lengths: Optional[int] = None, # total length of the samples seen (per device) train_loss: Optional[float] = None, ): self.iter += 1 metrics = {} self.history_samples.append(samples) self.history_training_loss.append(train_loss) if lengths is not None: self.history_lengths.append(lengths) # if lengths are passed, there should be as many values as samples assert len(self.history_samples) == len(self.history_lengths) self.history_wct.append(train_elapsed) if len(self.history_wct) == self.history_wct.maxlen: elapsed_batches = len(self.history_samples) - 1 elapsed_samples = self.history_samples[-1] - self.history_samples[0] elapsed_wct = self.history_wct[-1] - self.history_wct[0] samples_per_sec = elapsed_samples * world_size / elapsed_wct dev_samples_per_sec = elapsed_samples / elapsed_wct metrics.update( { "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct, "throughput/samples_per_sec": samples_per_sec, "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct, "throughput/device/samples_per_sec": dev_samples_per_sec, } ) if lengths is not None: elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0]) avg_length = elapsed_lengths / elapsed_batches metrics.update( { "throughput/tokens_per_sec": samples_per_sec * avg_length, "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length, "total_tokens": avg_length * world_size * samples, } ) if train_loss is not None: avg_loss = sum(self.history_training_loss) / len(self.history_training_loss) metrics.update( { "metric/train_loss": avg_loss, "metric/train_ppl": math.exp(avg_loss) } ) if flops_per_batch is not None: # sum of flops per batch across ranks 
self.history_flops.append(flops_per_batch * world_size) if len(self.history_flops) == self.history_flops.maxlen: elapsed_flops = sum(self.history_flops) - self.history_flops[0] elapsed_wct = self.history_wct[-1] - self.history_wct[0] flops_per_sec = elapsed_flops / elapsed_wct device_flops_per_sec = flops_per_sec / world_size metrics.update( {"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec} ) if self.flops_available: metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available metrics.update( { "time/train": train_elapsed / self.divider, "time/val": self.total_eval_wct / self.divider, "time/total": (train_elapsed + self.total_eval_wct) / self.divider, "samples": samples, } ) if self.iter % self.log_iter_interval == 0: self.log_dict(metrics, step_count) def eval_end(self, eval_elapsed: float): self.total_eval_wct += eval_elapsed # seconds class SpeedMonitorFabric(SpeedMonitorBase): def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None: # TODO: this will not work properly if a precision plugin is passed to Fabric flops_available = get_flops_available(fabric.device, fabric._connector._precision_input) super().__init__(flops_available, fabric.log_dict, *args, **kwargs) @fabric_rank_zero_only def on_train_batch_end(self, *args: Any, **kwargs: Any): super().on_train_batch_end(*args, **kwargs) class SpeedMonitorCallback(Callback): def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None: super().__init__() self.speed_monitor: Optional[SpeedMonitorBase] = None self.speed_monitor_kwargs = kwargs self.length_fn = length_fn self.batch_size = batch_size self.eval_t0: int = 0 self.train_t0: int = 0 self.total_lengths: int = 0 def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: if self.speed_monitor is not None: return # already setup # TODO: this will not work properly if a precision plugin is passed to Trainer flops_available = 
get_flops_available( trainer.strategy.root_device, trainer._accelerator_connector._precision_flag ) self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs) @trainer_rank_zero_only def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: if trainer.fit_loop._should_accumulate(): return self.train_t0 = time.perf_counter() @trainer_rank_zero_only def on_train_batch_end( self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int ) -> None: self.total_lengths += self.length_fn(batch) if trainer.fit_loop._should_accumulate(): return train_elapsed = time.perf_counter() - self.train_t0 assert self.speed_monitor is not None iter_num = trainer.fit_loop.total_batch_idx assert (measured_flops := pl_module.measured_flops) is not None self.speed_monitor.on_train_batch_end( (iter_num + 1) * self.batch_size, train_elapsed, # this assumes that device FLOPs are the same and that all devices have the same batch size trainer.world_size, flops_per_batch=measured_flops, lengths=self.total_lengths, ) @trainer_rank_zero_only def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None: self.eval_t0 = time.perf_counter() @trainer_rank_zero_only def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: eval_elapsed = time.perf_counter() - self.eval_t0 assert self.speed_monitor is not None self.speed_monitor.eval_end(eval_elapsed) def flops_per_param(config: Config, n_params: int) -> int: flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation # this assumes that all samples have a fixed length equal to the block size # which is most likely false during finetuning flops_per_seq = flops_per_token * config.block_size attn_flops_per_seq = config.n_layer * 2 * 2 * (config.n_embd * (config.block_size**2)) return flops_per_seq + attn_flops_per_seq def estimate_flops(model: GPT) -> int: """Measures 
estimated FLOPs for MFU. Refs: * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 """ # Using all parameters here is a naive overestimate because not all model parameters actually contribute to # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage # (~10%) than the measured FLOPs, which are lower but more realistic. # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. n_trainable_params = num_parameters(model, requires_grad=True) trainable_flops = flops_per_param(model.config, n_trainable_params) # forward + backward + gradients (assumes no gradient accumulation) ops_per_step = 3 if model.training else 1 n_frozen_params = num_parameters(model, requires_grad=False) frozen_flops = flops_per_param(model.config, n_frozen_params) # forward + backward frozen_ops_per_step = 2 if model.training else 1 return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops def measure_flops(model: GPT, x: torch.Tensor) -> int: """Measures real FLOPs for HFU""" flop_counter = FlopCounterMode(model, display=False) ctx = nullcontext() if model.training else torch.no_grad() with ctx, flop_counter: y = model(x) if model.training: y.sum().backward() return flop_counter.get_total_flops() ================================================ FILE: lit_gpt/tokenizer.py ================================================ import json from pathlib import Path from typing import Optional import torch class Tokenizer: def __init__(self, checkpoint_dir: Path) -> None: # some checkpoints have both files, `.model` takes precedence if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): from sentencepiece import SentencePieceProcessor self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) self.backend = "sentencepiece" self.bos_id = self.processor.bos_id() self.eos_id = self.processor.eos_id() 
elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): from tokenizers import Tokenizer as HFTokenizer self.processor = HFTokenizer.from_file(str(vocabulary_path)) self.backend = "huggingface" with open(checkpoint_dir / "tokenizer_config.json") as fp: config = json.load(fp) bos_token = config.get("bos_token") self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None self.eos_id = self.token_to_id(config["eos_token"]) else: raise NotImplementedError @property def vocab_size(self) -> int: if self.backend == "huggingface": return self.processor.get_vocab_size(with_added_tokens=False) if self.backend == "sentencepiece": return self.processor.vocab_size() raise RuntimeError def token_to_id(self, token: str) -> int: if self.backend == "huggingface": id_ = self.processor.token_to_id(token) elif self.backend == "sentencepiece": id_ = self.processor.piece_to_id(token) else: raise RuntimeError if id_ is None: raise ValueError(f"token {token!r} not found in the collection.") return id_ def encode( self, string: str, device: Optional[torch.device] = None, bos: bool = False, eos: bool = True, max_length: int = -1, ) -> torch.Tensor: if self.backend == "huggingface": tokens = self.processor.encode(string).ids elif self.backend == "sentencepiece": tokens = self.processor.encode(string) else: raise RuntimeError if bos: bos_id = self.bos_id if bos_id is None: raise NotImplementedError("This tokenizer does not define a bos token") tokens = [bos_id] + tokens if eos: tokens = tokens + [self.eos_id] if max_length > 0: tokens = tokens[:max_length] return torch.tensor(tokens, dtype=torch.int, device=device) def decode(self, tensor: torch.Tensor) -> str: tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() return self.processor.decode(tokens) ================================================ FILE: lit_gpt/utils.py ================================================ """Utility functions for training and inference.""" import pickle import sys 
import warnings from contextlib import contextmanager from functools import partial from io import BytesIO from pathlib import Path from types import MethodType from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union import torch import torch.nn as nn import torch.utils._device from lightning.fabric.loggers import CSVLogger from torch.serialization import normalize_storage_type def find_multiple(n: int, k: int) -> int: assert k > 0 if n % k == 0: return n return n + k - (n % k) def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) @contextmanager def quantization(mode: Optional[str] = None): if mode is None: yield return if mode == "bnb.int8": from quantize.bnb import InferenceLinear8bitLt quantized_linear_cls = InferenceLinear8bitLt elif mode == "bnb.fp4": from quantize.bnb import Linear4bit # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses class QuantizedLinear(Linear4bit): def __init__(self, *args, **kwargs): super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) quantized_linear_cls = QuantizedLinear elif mode == "bnb.fp4-dq": from quantize.bnb import Linear4bit class QuantizedLinear(Linear4bit): def __init__(self, *args, **kwargs): super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) quantized_linear_cls = QuantizedLinear elif mode == "bnb.nf4": from quantize.bnb import Linear4bit class QuantizedLinear(Linear4bit): def __init__(self, *args, **kwargs): super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) quantized_linear_cls = QuantizedLinear elif mode == "bnb.nf4-dq": from quantize.bnb import Linear4bit class QuantizedLinear(Linear4bit): def __init__(self, *args, **kwargs): super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) quantized_linear_cls = 
QuantizedLinear elif mode == "gptq.int4": from quantize.gptq import ColBlockQuantizedLinear class QuantizedLinear(ColBlockQuantizedLinear): def __init__(self, *args, **kwargs): super().__init__(*args, bits=4, tile_cols=-1, **kwargs) quantized_linear_cls = QuantizedLinear else: raise ValueError(f"Unknown quantization mode: {mode}") torch_linear_cls = torch.nn.Linear torch.nn.Linear = quantized_linear_cls try: yield finally: torch.nn.Linear = torch_linear_cls # this is taken from torchhacks https://github.com/lernapparat/torchhacks class NotYetLoadedTensor: def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): self.metatensor = metatensor self.archiveinfo = archiveinfo self.storageinfo = storageinfo self.rebuild_args = rebuild_args @classmethod def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): ret = func(*args) if isinstance(ret, NotYetLoadedTensor): old_lt = ret._load_tensor def _load_tensor(): t = old_lt() return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) ret._load_tensor = _load_tensor return ret return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) @classmethod def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): if isinstance(data, NotYetLoadedTensor): old_lt = data._load_tensor def _load_tensor(): t = old_lt() return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) data._load_tensor = _load_tensor return data return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) @classmethod def rebuild_tensor_v2( cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None ): rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) metatensor = torch._utils._rebuild_tensor_v2( storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata ) storageinfo = storage.archiveinfo return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, 
            rebuild_args)

    def _load_tensor(self):
        name, storage_cls, fn, device, size = self.storageinfo
        dtype = self.metatensor.dtype

        uts = (
            self.archiveinfo.zipfile_context.zf.get_storage_from_record(
                f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage
            )
            ._typed_storage()
            ._untyped_storage
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args]
        return func(*loaded_args, **kwargs)
        # gc.collect would be costly here, maybe do it optionally

    def __getattr__(self, name):
        # properties
        ## TODO: device, is_...??
        ## TODO: mH, mT, H, T, data, imag, real
        ## name ???
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)
        if name in {"size"}:
            return getattr(self.metatensor, name)
        # materializing with contiguous is needed for quantization
        if name in {"contiguous"}:
            return getattr(self._load_tensor(), name)

        raise AttributeError(f"{type(self)} does not have {name}")

    def __repr__(self):
        return f"NotYetLoadedTensor({repr(self.metatensor)})"


class LazyLoadingUnpickler(pickle.Unpickler):
    def __init__(self, file, zipfile_context):
        super().__init__(file)
        self.zipfile_context = zipfile_context

    def find_class(self, module, name):
        res = super().find_class(module, name)
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
            return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
        if module == "torch._utils" and name == "_rebuild_parameter":
            return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
        return res

    def persistent_load(self, pid):
        name, cls, fn, device, size = pid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
        s.archiveinfo = pid
        return s


class lazy_load:
    def __init__(self, fn):
        self.zf = torch._C.PyTorchFileReader(str(fn))
        with BytesIO(self.zf.get_record("data.pkl")) as pkl:
            mup = LazyLoadingUnpickler(pkl, self)
            self.sd = mup.load()

    def __enter__(self):
        return self.sd

    def __exit__(self, exc_type, exc_val, exc_tb):
        del self.zf  # I don't think there is a way to force closing...
        self.zf = None


def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
    files = {
        "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
            checkpoint_dir / "tokenizer.model"
        ).is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)


class SavingProxyForStorage:
    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)

    def __reduce_ex__(self, protocol_version):
        assert False, "this should be handled with out of band"


class SavingProxyForTensor:
    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version)
        assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
        storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
        self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        if protocol_version != self.protocol_version:
            raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
        return self.reduce_ret_fn, self.reduce_args


class IncrementalPyTorchPickler(pickle.Pickler):
    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage_dtypes = {}
        self.saver = saver
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype.
            # If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None


class incremental_save:
    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()


T = TypeVar("T")


def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T:
    logger = cls(*args, **kwargs)

    def merge_by(dicts, key):
        from collections import defaultdict

        out = defaultdict(dict)
        for d in dicts:
            if key in d:
                out[d[key]].update(d)
        return [v for _, v in sorted(out.items())]

    def save(self) -> None:
        """Overridden to merge CSV by the step number."""
        import csv

        if not self.metrics:
            return
        metrics = merge_by(self.metrics, "step")
        keys = sorted({k for m in metrics for k in m})
        with self._fs.open(self.metrics_file_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=keys)
            writer.writeheader()
            writer.writerows(metrics)

    logger.experiment.save = MethodType(save, logger.experiment)

    return logger


def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
) -> torch.Tensor:
    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
    # the memory usage in fine-tuning settings with low number of parameters.
    # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
    # the memory spike's magnitude

    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

        # chunk cross entropy
        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
        loss_chunks = [
            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
        ]
        return torch.cat(loss_chunks).mean()

    # no chunking at all
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if chunk_size == 0:
        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

    # lm_head wasn't chunked, chunk cross entropy
    logit_chunks = logits.split(chunk_size)
    target_chunks = targets.split(chunk_size)
    loss_chunks = [
        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    return torch.cat(loss_chunks).mean()


def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    for checkpoint_name, attribute_name in mapping.items():
        full_checkpoint_name = prefix + checkpoint_name
        if full_checkpoint_name in state_dict:
            full_attribute_name = prefix + attribute_name
            state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
    return state_dict


def get_default_supported_precision(training: bool, tpu: bool = False) -> str:
    """Return default precision that is supported by the hardware.

    Args:
        training: `-mixed` or `-true` version of the precision to use
        tpu: whether TPU device is used

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    if tpu:
        return "32-true"
    if not torch.cuda.is_available() or torch.cuda.is_bf16_supported():
        return "bf16-mixed" if training else "bf16-true"
    return "16-mixed" if training else "16-true"

================================================
FILE: pretrain/tinyllama.py
================================================
import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually
from lit_gpt.model import GPT, Block, Config, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger
from lit_gpt import FusedCrossEntropyLoss
import random

model_name = "tiny_LLaMA_1b"
name = "tinyllama_1b"
out_dir = Path("out") / name

# Hyperparameters
num_of_devices = 8
global_batch_size = 512
learning_rate = 4e-4
micro_batch_size = 8
max_step = 715256 * 2
warmup_steps = 2000
log_step_interval = 10
eval_iters = 100
save_step_interval = 5000
eval_step_interval = 5000

weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
min_lr = 4e-5

batch_size = global_batch_size // num_of_devices
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
warmup_iters = warmup_steps * gradient_accumulation_steps

max_iters = max_step * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps


# Weight each dataset by its size. To use a different weight for a dataset, change the weight in the list below.
train_data_config = [
    ("train_slim", 0.693584),
    ("train_star", 0.306416),
]

val_data_config = [
    ("validation", 1.0),
]

hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval)
wandb_logger = WandbLogger()


def setup(
    devices: int = 8,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    precision: Optional[str] = None,
    tpu: bool = False,
    resume: Union[bool, Path] = False,
) -> None:
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    if devices > 1:
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            devices = "auto"
            strategy = XLAStrategy(sync_module_states=False)
        else:
            strategy = FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy=None,
                state_dict_type="full",
                limit_all_gathers=True,
                cpu_offload=False,
            )
    else:
        strategy = "auto"

    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
    fabric.print(hparams)
    # fabric.launch(main, train_data_dir, val_data_dir, resume)
    main(fabric, train_data_dir, val_data_dir, resume)


def main(fabric, train_data_dir, val_data_dir, resume):
    monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval)

    if fabric.global_rank == 0:
        out_dir.mkdir(parents=True, exist_ok=True)

    config = Config.from_name(model_name)

    train_dataloader, val_dataloader = create_dataloaders(
        batch_size=micro_batch_size,
        block_size=config.block_size,
        fabric=fabric,
        train_data_dir=train_data_dir,
        val_data_dir=val_data_dir,
        seed=3407,
    )
    if val_dataloader is None:
        train_dataloader = fabric.setup_dataloaders(train_dataloader)
    else:
        train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    fabric.seed_everything(3407)  # same seed for every process to init model (FSDP)

    fabric.print(f"Loading model with {config.__dict__}")
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=False):
        model = GPT(config)
        model.apply(partial(model._init_weights, n_layer=config.n_layer))

    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters {num_parameters(model):,}")

    model = fabric.setup(model)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
    )
    # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True)
    optimizer = fabric.setup_optimizers(optimizer)

    state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}

    if resume is True:
        resume = sorted(out_dir.glob("*.pth"))[-1]
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    train_time = time.perf_counter()
    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
    model = state["model"]
    optimizer = state["optimizer"]

    if val_dataloader is not None:
        validate(fabric, model, val_dataloader)  # sanity check

    with torch.device("meta"):
        meta_model = GPT(model.config)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        estimated_flops = estimate_flops(meta_model) * micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))
        # measure_flops runs on the meta device and would trigger a fusedRMSNorm error, so it stays disabled
        # measured_flops = measure_flops(meta_model, x)
        # fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    total_lengths = 0
    total_t0 = time.perf_counter()

    if fabric.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    initial_iter = state["iter_num"]
    curr_iter = 0

    loss_func = FusedCrossEntropyLoss()
    for train_data in train_dataloader:
        # resume loader state. This is not elegant but it works. Should rewrite it in the future.
        if resume:
            if curr_iter < initial_iter:
                curr_iter += 1
                continue
            else:
                resume = False
                curr_iter = -1
                fabric.barrier()
                fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))
        if state["iter_num"] >= max_iters:
            break

        # determine and set the learning rate for this iteration
        lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids = train_data[:, 0 : model.config.block_size].contiguous()
        targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
        is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)
            loss = loss_func(logits, targets)
            # loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            fabric.backward(loss / gradient_accumulation_steps)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1
        elif fabric.device.type == "xla":
            xm.mark_step()
        state["iter_num"] += 1
        # input_id: B L
        total_lengths += input_ids.size(1)
        t1 = time.perf_counter()
        fabric.print(
            f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:"
            f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. "
            # print days as well
            f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. "
        )

        monitor.on_train_batch_end(
            state["iter_num"] * micro_batch_size,
            t1 - total_t0,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            fabric.world_size,
            state["step_count"],
            flops_per_batch=estimated_flops,
            lengths=total_lengths,
            train_loss = loss.item()
        )

        if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader)
            t1 = time.perf_counter() - t0
            monitor.eval_end(t1)
            fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
            fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
            fabric.barrier()
        if not is_accumulating and state["step_count"] % save_step_interval == 0:
            checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
            fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
            fabric.save(checkpoint_path, state)


@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:
    fabric.print("Validating ...")
    model.eval()

    losses = torch.zeros(eval_iters, device=fabric.device)
    for k, val_data in enumerate(val_dataloader):
        if k >= eval_iters:
            break
        input_ids = val_data[:, 0 : model.config.block_size].contiguous()
        targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
        logits = model(input_ids)
        loss = chunked_cross_entropy(logits, targets, chunk_size=0)

        # loss_func = FusedCrossEntropyLoss()
        # loss = loss_func(logits, targets)
        losses[k] = loss.item()

    out = losses.mean()

    model.train()
    return out


def create_dataloader(
    batch_size: int, block_size: int, data_dir: Path, fabric, shuffle:
    bool = True, seed: int = 12345, split="train"
) -> DataLoader:
    datasets = []
    data_config = train_data_config if split == "train" else val_data_config
    for prefix, _ in data_config:
        filenames = sorted(glob.glob(str(data_dir / f"{prefix}*")))
        random.seed(seed)
        random.shuffle(filenames)

        dataset = PackedDataset(
            filenames,
            # n_chunks controls the buffer size.
            # Note that the buffer size also affects the random shuffle:
            # PackedDataset is an IterableDataset, so shuffling is done by prefetching a buffer and shuffling it.
            n_chunks=8,
            block_size=block_size,
            shuffle=shuffle,
            seed=seed+fabric.global_rank,
            num_processes=fabric.world_size,
            process_rank=fabric.global_rank,
        )
        datasets.append(dataset)

    if not datasets:
        raise RuntimeError(
            f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
        )

    weights = [weight for _, weight in data_config]
    sum_weights = sum(weights)
    weights = [el / sum_weights for el in weights]

    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)

    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


def create_dataloaders(
    batch_size: int,
    block_size: int,
    fabric,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    seed: int = 12345,
) -> Tuple[DataLoader, DataLoader]:
    # Increase by one because we need the next word as well
    effective_block_size = block_size + 1
    train_dataloader = create_dataloader(
        batch_size=batch_size,
        block_size=effective_block_size,
        fabric=fabric,
        data_dir=train_data_dir,
        shuffle=True,
        seed=seed,
        split="train"
    )
    val_dataloader = (
        create_dataloader(
            batch_size=batch_size,
            block_size=effective_block_size,
            fabric=fabric,
            data_dir=val_data_dir,
            shuffle=False,
            seed=seed,
            split="validation"
        )
        if val_data_dir
        else None
    )
    return train_dataloader, val_dataloader


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)


if __name__ == "__main__":
    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    from jsonargparse import CLI

    CLI(setup)

================================================
FILE: pretrain/tinyllama_code.py
================================================
import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually
from lit_gpt.model import GPT, Block, Config, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger
from lit_gpt import FusedCrossEntropyLoss
import random

model_name = "tiny_LLaMA_1b"
name = "tiny_LLaMA_1b"
out_dir = Path("out") / name
checkpoint_path = "out/TinyLlama-1.1B-intermediate-step-240k-503b/lit_model.pth"
# Hyperparameters
num_of_devices = 6
global_batch_size = 360
learning_rate = 2e-4
min_lr = 2e-5
micro_batch_size = 6
max_step = 10000
warmup_steps = 0
log_step_interval = 1
eval_iters = 1000000
save_step_interval = 2000
eval_step_interval = 2000

weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True

batch_size = global_batch_size // num_of_devices
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
warmup_iters = warmup_steps * gradient_accumulation_steps

max_iters = max_step * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps


# Weight each dataset by its size. To use a different weight for a dataset, change the weight in the list below.
train_data_config = [
    ("train_starcoder", 1),
]

val_data_config = [
    ("validation", 1.0),
]

hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval)
wandb_logger = WandbLogger()


def setup(
    devices: int = 8,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    precision: Optional[str] = None,
    tpu: bool = False,
    resume: Union[bool, Path] = False,
) -> None:
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    if devices > 1:
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            devices = "auto"
            strategy = XLAStrategy(sync_module_states=False)
        else:
            strategy = FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy=None,
                state_dict_type="full",
                limit_all_gathers=True,
                cpu_offload=False,
            )
    else:
        strategy = "auto"

    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
    fabric.print(hparams)
    fabric.launch(main, train_data_dir, val_data_dir, resume)
    # main(fabric, train_data_dir, val_data_dir, resume)


def main(fabric, train_data_dir, val_data_dir, resume):
    monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval)

    if fabric.global_rank == 0:
        out_dir.mkdir(parents=True, exist_ok=True)

    config = Config.from_name(model_name)

    train_dataloader, val_dataloader = create_dataloaders(
        batch_size=micro_batch_size,
        block_size=config.block_size,
        fabric=fabric,
        train_data_dir=train_data_dir,
        val_data_dir=val_data_dir,
        seed=3407,
    )
    if val_dataloader is None:
        train_dataloader = fabric.setup_dataloaders(train_dataloader)
    else:
        train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    fabric.seed_everything(3407)  # same seed for every process to init model (FSDP)

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True):
        model = GPT(config)

    model = fabric.setup(model)
    fabric.load_raw(checkpoint_path, model, strict=True)

    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters {num_parameters(model):,}")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
    )
    # import bitsandbytes as bnb
    # optimizer = bnb.optim.AdamW8bit(
    #     model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2)
    # )
    # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True)
    optimizer = fabric.setup_optimizers(optimizer)

    state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}

    if resume is True:
        resume = sorted(out_dir.glob("*.pth"))[-1]
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    train_time = time.perf_counter()
    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
    model = state["model"]
    optimizer = state["optimizer"]

    if val_dataloader is not None:
        validate(fabric, model, val_dataloader)  # sanity check

    with torch.device("meta"):
        meta_model = GPT(model.config)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        estimated_flops = estimate_flops(meta_model) * micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))
        # measure_flops runs on the meta device and would trigger a fusedRMSNorm error, so it stays disabled
        # measured_flops = measure_flops(meta_model, x)
        # fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    total_lengths = 0
    total_t0 = time.perf_counter()

    if fabric.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    initial_iter = state["iter_num"]
    curr_iter = 0

    loss_func = FusedCrossEntropyLoss()
    for train_data in train_dataloader:
        # resume loader state. This is not elegant but it works. Should rewrite it in the future.
        if resume:
            if curr_iter < initial_iter:
                curr_iter += 1
                continue
            else:
                resume = False
                curr_iter = -1
                fabric.barrier()
                fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))
        if state["iter_num"] >= max_iters:
            break

        # determine and set the learning rate for this iteration
        lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids = train_data[:, 0 : model.config.block_size].contiguous()
        targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
        is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)
            loss = loss_func(logits, targets)
            # loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            fabric.backward(loss / gradient_accumulation_steps)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1
        elif fabric.device.type == "xla":
            xm.mark_step()
        state["iter_num"] += 1
        # input_id: B L
        total_lengths += input_ids.size(1)
        t1 = time.perf_counter()
        fabric.print(
            f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:"
            f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. "
            # print days as well
            f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. "
        )

        monitor.on_train_batch_end(
            state["iter_num"] * micro_batch_size,
            t1 - total_t0,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            fabric.world_size,
            state["step_count"],
            flops_per_batch=estimated_flops,
            lengths=total_lengths,
            train_loss = loss.item()
        )

        if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader)
            t1 = time.perf_counter() - t0
            monitor.eval_end(t1)
            fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
            fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
            fabric.barrier()
        if not is_accumulating and state["step_count"] % save_step_interval == 0:
            checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
            fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
            fabric.save(checkpoint_path, state)


@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:
    fabric.print("Validating ...")
    model.eval()

    losses = torch.zeros(eval_iters, device=fabric.device)
    for k, val_data in enumerate(val_dataloader):
        if k >= eval_iters:
            break
        input_ids = val_data[:, 0 : model.config.block_size].contiguous()
        targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
        logits = model(input_ids)
        loss = chunked_cross_entropy(logits, targets, chunk_size=0)

        # loss_func = FusedCrossEntropyLoss()
        # loss = loss_func(logits, targets)
        losses[k] = loss.item()

    out = losses.mean()

    model.train()
    return out


def create_dataloader(
    batch_size: int, block_size: int, data_dir: Path, fabric, shuffle:
bool = True, seed: int = 12345, split="train"
) -> DataLoader:
    datasets = []
    data_config = train_data_config if split == "train" else val_data_config
    for prefix, _ in data_config:
        filenames = sorted(glob.glob(str(data_dir / f"{prefix}*")))
        random.seed(seed)
        random.shuffle(filenames)

        dataset = PackedDataset(
            filenames,
            # n_chunks controls the buffer size.
            # Note that the buffer size also affects the random shuffle:
            # PackedDataset is an IterableDataset, so shuffling is done by
            # prefetching a buffer of chunks and shuffling within that buffer.
            n_chunks=8,
            block_size=block_size,
            shuffle=shuffle,
            seed=seed+fabric.global_rank,
            num_processes=fabric.world_size,
            process_rank=fabric.global_rank,
        )
        datasets.append(dataset)

    if not datasets:
        raise RuntimeError(
            f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
        )

    # Normalize the per-dataset sampling weights so they sum to 1
    weights = [weight for _, weight in data_config]
    sum_weights = sum(weights)
    weights = [el / sum_weights for el in weights]

    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)

    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


def create_dataloaders(
    batch_size: int,
    block_size: int,
    fabric,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    seed: int = 12345,
) -> Tuple[DataLoader, DataLoader]:
    # Increase by one because we need the next token as the prediction target
    effective_block_size = block_size + 1
    train_dataloader = create_dataloader(
        batch_size=batch_size,
        block_size=effective_block_size,
        fabric=fabric,
        data_dir=train_data_dir,
        shuffle=True,
        seed=seed,
        split="train"
    )
    val_dataloader = (
        create_dataloader(
            batch_size=batch_size,
            block_size=effective_block_size,
            fabric=fabric,
            data_dir=val_data_dir,
            shuffle=False,
            seed=seed,
            split="validation"
        )
        if val_data_dir
        else None
    )
    return train_dataloader, val_dataloader


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return
learning_rate * it / warmup_iters # 2) if it > lr_decay_iters, return min learning rate if it > lr_decay_iters: return min_lr # 3) in between, use cosine decay down to min learning rate decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 return min_lr + coeff * (learning_rate - min_lr) if __name__ == "__main__": # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" # torch.backends.cuda.enable_flash_sdp(False) torch.set_float32_matmul_precision("high") from jsonargparse import CLI CLI(setup) ================================================ FILE: requirements.txt ================================================ torch>=2.1.0dev lightning==2.1.2 lightning[app] jsonargparse[signatures] # CLI pandas pyarrow tokenizers sentencepiece wandb zstd # for finetuning bitsandbytes==0.40.0 transformers==4.31.0 peft==0.4.0 accelerate==0.21.0 einops==0.6.1 evaluate==0.4.0 scikit-learn==1.2.2 sentencepiece==0.1.99 wandb==0.15.3 # other optional dependencies are # sentencepiece # pythia, falcon, redpajama # tokenizers # llama-based models # bitsandbytes>=0.41.1 # quantize/bnb.py # scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released # datasets # quantize/gptq.py # zstandard # scripts/prepare_redpajama.py # git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval ================================================ FILE: script.sh ================================================ python scripts/convert_hf_checkpoint.py --checkpoint_dir out/TinyLlama-1.1B-900B --model_name tiny_LLaMA_1b python test_weight.py --checkpoint_dir out/TinyLlama-1.1B-intermediate-900B python pretrain/tinyllama_code.py --devices 8 --train_data_dir data/code_specialist_python_java_javascript_c_go_8192 python scripts/prepare_starcoder.py --source_path data/starcoderdata/ --tokenizer_path data/llama 
--destination_path data/code_specialist_python_java_javascript_c_go_8192 --split train --percentage 1.0 --filenames_subset ["python","cpp","go","java","javascript"] --chunk_size 4194816 /data/TinyLlama/out/code_tiny_LLaMA_1b_python_java_go_cpp_javascript/iter-032000-ckpt.pth python scripts/convert_lit_checkpoint.py --out_dir /data/TinyLlama/out/tiny_LLaMA_1b/ --checkpoint_name iter-100000-ckpt.pth --model_name tiny_LLaMA_1b ================================================ FILE: scripts/convert_hf_checkpoint.py ================================================ import contextlib import gc import json import sys from functools import partial from pathlib import Path from typing import Dict, List, Literal, Optional, Tuple, Union import torch # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) from lit_gpt import Config from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load def copy_weights_gpt_neox( state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, ) -> None: weight_map = { "gpt_neox.embed_in.weight": "transformer.wte.weight", "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, "gpt_neox.layers.{}.attention.bias": None, "gpt_neox.layers.{}.attention.masked_bias": None, "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", 
"gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias", "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias", "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias", "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", "embed_out.weight": "lm_head.weight", } for name, param in hf_weights.items(): if "gpt_neox.layers" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name] if to_name is None: continue to_name = to_name.format(number) else: to_name = weight_map[name] param = load_param(param, name, dtype) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_falcon( size: Literal["7b", "40b"], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, ) -> None: weight_map = { "transformer.word_embeddings.weight": "transformer.wte.weight", "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", "transformer.ln_f.bias": "transformer.ln_f.bias", "transformer.ln_f.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size if size == "7b": weight_map.update( { "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", } ) elif size 
== "40b": weight_map.update( { "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", } ) else: raise NotImplementedError for name, param in hf_weights.items(): if "transformer.h" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name].format(number) else: to_name = weight_map[name] param = load_param(param, name, dtype) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_hf_llama( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.swiglu.w1.weight", "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.swiglu.w2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.swiglu.w3.weight", "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } for name, param in hf_weights.items(): if "model.layers" in name: from_name, number = layer_template(name, 2) qkv = qkv_weights.setdefault(number, [None, None, None]) if "q_proj" in name: qkv[0] = param elif "k_proj" in 
name: qkv[1] = param elif "v_proj" in name: qkv[2] = param to_name = weight_map[from_name] if to_name is None: continue to_name = to_name.format(number) else: to_name = weight_map[name] param = load_param(param, name, dtype) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param for i, (q, k, v) in list(qkv_weights.items()): if q is None or k is None or v is None: # split across different .bin files continue q = load_param(q, f"layer {i} q", dtype) k = load_param(k, f"layer {i} k", dtype) v = load_param(v, f"layer {i} v", dtype) q_per_kv = config.n_head // config.n_query_groups qs = torch.split(q, config.head_size * q_per_kv) ks = torch.split(k, config.head_size) vs = torch.split(v, config.head_size) cycled = [t for group in zip(qs, ks, vs) for t in group] qkv = torch.cat(cycled) state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv del qkv_weights[i] def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: split = layer_name.split(".") number = int(split[idx]) split[idx] = "{}" from_name = ".".join(split) return from_name, number def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` print(f"Loading {name!r} into RAM") param = param._load_tensor() if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: print(f"Converting {name!r} from {param.dtype} to {dtype}") param = param.to(dtype) return param @torch.inference_mode() def convert_hf_checkpoint( *, checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), model_name: Optional[str] = None, dtype: Optional[str] = None, ) -> None: if model_name is None: model_name = checkpoint_dir.name if dtype is not None: dtype = getattr(torch, dtype) config = Config.from_name(model_name) print(f"Model config {config.__dict__}") with open(checkpoint_dir / "lit_config.json", "w") as json_config: 
json.dump(config.__dict__, json_config) if "falcon" in model_name: copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") elif config._mlp_class == "LLaMAMLP": # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) else: copy_fn = copy_weights_gpt_neox # initialize a new empty state dict to hold our new weights sd = {} # Load the json file containing weight mapping pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file with open(pytorch_bin_map_json_path) as json_map: bin_index = json.load(json_map) bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} else: bin_files = set(checkpoint_dir.glob("*.bin")) if not bin_files: raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files") with incremental_save(checkpoint_dir / "lit_model.pth") as saver: # for checkpoints that split the QKV across several files, we need to keep all the bin files # open, so we use `ExitStack` to close them all together at the end with contextlib.ExitStack() as stack: for bin_file in sorted(bin_files): print("Processing", bin_file) hf_weights = stack.enter_context(lazy_load(bin_file)) copy_fn(sd, hf_weights, saver=None, dtype=dtype) gc.collect() print("Saving converted checkpoint") saver.save(sd) if __name__ == "__main__": from jsonargparse import CLI CLI(convert_hf_checkpoint) ================================================ FILE: scripts/convert_lit_checkpoint.py ================================================ import contextlib import gc import sys from functools import partial from pathlib import Path from typing import Dict, Literal, Optional, Tuple, Union from dataclasses import asdict import json import torch # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) from lit_gpt import Config from 
lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load # from scripts.convert_hf_checkpoint import layer_template, load_param def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: split = layer_name.split(".") number = int(split[idx]) split[idx] = "{}" from_name = ".".join(split) return from_name, number def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` print(f"Loading {name!r} into RAM") param = param._load_tensor() if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: print(f"Converting {name!r} from {param.dtype} to {dtype}") param = param.to(dtype) return param def copy_weights_falcon( size: Literal["7b", "40b"], state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ): weight_map = { "transformer.wte.weight": "transformer.word_embeddings.weight", "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", "transformer.ln_f.bias": "transformer.ln_f.bias", "transformer.ln_f.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size if size == "7b": weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", } ) elif size == "40b": weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", 
"transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", } ) else: raise NotImplementedError for name, param in lit_weights.items(): if "transformer.h" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name].format(number) else: to_name = weight_map[name] param = load_param(param, name, None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gpt_neox( state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "gpt_neox.embed_in.weight", "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", "lm_head.weight": "embed_out.weight", } for name, param in lit_weights.items(): if "transformer.h" in name: from_name, number = layer_template(name, 2) to_name 
= weight_map[from_name].format(number) else: to_name = weight_map[name] param = load_param(param, name, None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_llama( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ): weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", "transformer.ln_f.weight": "model.norm.weight", "lm_head.weight": "lm_head.weight", } for name, param in lit_weights.items(): if name.endswith(".attn.attn.weight"): from_name, number = layer_template(name, 2) q = "model.layers.{}.self_attn.q_proj.weight".format(number) k = "model.layers.{}.self_attn.k_proj.weight".format(number) v = "model.layers.{}.self_attn.v_proj.weight".format(number) qkv = load_param(param, name,None) qp, kp, vp = tensor_split(qkv, config) for to_name, param in zip((q, k, v), (qp, kp, vp)): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param elif "transformer.h" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name] if to_name is None: continue to_name = to_name.format(number) param = load_param(param, name,None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param else: to_name = weight_map[name] param = load_param(param, name, None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def tensor_split( param: 
Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def kstart(start, blen, klen) -> int: """returns start index of keys in batch""" return start + (blen - (klen * 2)) def vstart(start, blen, klen) -> int: """returns start index of values in batch""" return start + blen - klen def vend(start, blen) -> int: """returns last index of values in batch""" return start + blen # num observations nobs = param.shape[0] # batch length blen = nobs // config.n_query_groups # key length in batch klen = config.head_size # value length in batch vlen = config.head_size # the starting index of each new batch starts = range(0, nobs, blen) # the indices to splice on splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] qc = () kc = () vc = () for splice in splices: qs, ks, vs, ve = splice qc += (param[qs:ks, :],) kc += (param[ks:vs, :],) vc += (param[vs:ve, :],) q = torch.cat(qc) k = torch.cat(kc) v = torch.cat(vc) return q, k, v def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return lit_weights.get("model", lit_weights) def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: weight_names = {wk.split(".")[-1] for wk in lit_weights} # LoRA or QLoRA if any("lora" in wn for wn in weight_names): raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") # adapter v2. adapter_bias will only be in adapter_v2 elif "adapter_bias" in weight_names: raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") # adapter. 
gating_factor is in adapter and adapter_v2
    elif "gating_factor" in weight_names:
        raise NotImplementedError("Converting models finetuned with adapter not yet supported.")


def get_tinyllama_init_hf_config() -> dict:
    # Template HF config; fields set to None are filled in by convert_config_lit_to_hf
    return {
        "architectures": ["LlamaForCausalLM"],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": None,
        "initializer_range": 0.02,
        "intermediate_size": None,
        "max_position_embeddings": None,
        "model_type": "llama",
        "num_attention_heads": None,
        "num_hidden_layers": None,
        "num_key_value_heads": None,
        "pretraining_tp": 1,
        "rms_norm_eps": None,
        "rope_scaling": None,
        "tie_word_embeddings": False,
        "torch_dtype": "float32",
        "transformers_version": "4.31.0.dev0",
        "use_cache": True,
        "vocab_size": None,
    }


def convert_config_lit_to_hf(lit_config_dict: dict) -> dict:
    # Maps lit-gpt config keys to their Hugging Face LlamaConfig equivalents
    lit_hf_mapping = {
        "block_size": "max_position_embeddings",
        "vocab_size": "vocab_size",
        "n_layer": "num_hidden_layers",
        "n_embd": "hidden_size",
        "n_head": "num_attention_heads",
        "n_query_groups": "num_key_value_heads",
        "intermediate_size": "intermediate_size",
        "norm_eps": "rms_norm_eps",
    }
    hf_config_dict = get_tinyllama_init_hf_config()

    for lit_key, hf_key in lit_hf_mapping.items():
        hf_config_dict[hf_key] = lit_config_dict[lit_key]
    return hf_config_dict


@torch.inference_mode()
def convert_lit_checkpoint(*, checkpoint_name: str, out_dir: Path, model_name: str, model_only: bool = True) -> None:
    config = Config.from_name(model_name)

    if "falcon" in model_name:
        copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b")
    elif config._mlp_class == "LLaMAMLP":
        copy_fn = partial(copy_weights_llama, config)
    else:
        copy_fn = copy_weights_gpt_neox

    # initialize a new empty state dict to hold our new weights
    sd = {}

    # checkpoint_name cannot be hardcoded because several different output files may exist, such as
    # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth")
    pth_file = out_dir / checkpoint_name
    bin_file = pth_file.with_suffix(".bin")
with incremental_save(bin_file) as saver: with contextlib.ExitStack() as stack: lit_weights = stack.enter_context(lazy_load(pth_file)) lit_weights = maybe_unwrap_state_dict(lit_weights) check_conversion_supported(lit_weights) # Incremental save will trigger error copy_fn(sd, lit_weights, saver=None) gc.collect() saver.save(sd) # convert lit config file to hf-style if not model_only: print('Converting config file...') lit_config = asdict(config) hf_config = convert_config_lit_to_hf(lit_config) config_path = out_dir / "config.json" with open(config_path, "w") as f: json.dump(hf_config, f, indent=4) if __name__ == "__main__": from jsonargparse import CLI CLI(convert_lit_checkpoint, as_positional=False) ================================================ FILE: scripts/prepare_redpajama.py ================================================ import glob import json import os import sys from pathlib import Path import numpy as np from tqdm import tqdm # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) import lit_gpt.packed_dataset as packed_dataset from lit_gpt import Config, Tokenizer filenames_sample = [ "arxiv_sample.jsonl", "book_sample.jsonl", "c4_sample.jsonl", "cc_2019-30_sample.jsonl", "cc_2020-05_sample.jsonl", "cc_2021-04_sample.jsonl", "cc_2022-05_sample.jsonl", "cc_2023-06_sample.jsonl", "github_sample.jsonl", "stackexchange_sample.jsonl", "wikipedia_sample.jsonl", ] filename_sets = { "arxiv": "arxiv/arxiv*", "book": "book/book*", "c4": "c4/c4-train*", "common_crawl": "common_crawl/*", "github": "github/filtered*", "stackexchange": "stackexchange/stackexchange*", "wikipedia": "wikipedia/wiki*", } def prepare_sample( source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" ) -> None: """Prepare the "Red Pajama" dataset using the original tokenizer.""" destination_path.mkdir(parents=True, exist_ok=True) tokenizer = Tokenizer(checkpoint_dir) for name in 
filenames_sample: if match and match not in name: continue filepath = source_path / name if not filepath.is_file(): raise RuntimeError( f"Input file not found at {filepath}. \nMake sure you download the data, e.g. wget -i" " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" ) prefix, _ = os.path.splitext(name) builder = packed_dataset.PackedDatasetBuilder( outdir=destination_path, prefix=prefix, chunk_size=chunk_size, sep_token=tokenizer.eos_id, dtype="auto", vocab_size=tokenizer.vocab_size, ) print(f"Processing {name}") with open(filepath, encoding="utf-8") as f: for row in tqdm(f): text = json.loads(row)["text"] text_ids = tokenizer.encode(text) builder.add_array(np.array(text_ids, dtype=builder.dtype)) builder.write_reminder() def prepare_full( source_path: Path, checkpoint_dir: Path, destination_path: Path, chunk_size: int, match: str = "" ) -> None: """Prepare the "Red Pajama" dataset using the original tokenizer.""" import zstandard as zstd destination_path.mkdir(parents=True, exist_ok=True) tokenizer = Tokenizer(checkpoint_dir) for set_name, pattern in filename_sets.items(): if match and match not in set_name: continue is_cc = set_name == "common_crawl" filenames = glob.glob(os.path.join(source_path, pattern), recursive=True) if not filenames: raise RuntimeError( f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. 
wget -i" " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" ) builder = packed_dataset.PackedDatasetBuilder( outdir=destination_path, prefix=set_name, chunk_size=chunk_size, sep_token=tokenizer.eos_id, dtype="auto", vocab_size=tokenizer.vocab_size, ) for name in filenames: filepath = source_path / name print(f"Processing {name}") if is_cc: with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: for row in tqdm(f): text = json.loads(row)["text"] text_ids = tokenizer.encode(text) builder.add_array(np.array(text_ids, dtype=builder.dtype)) else: with open(filepath, encoding="utf-8") as f: for row in tqdm(f): text = json.loads(row)["text"] text_ids = tokenizer.encode(text) builder.add_array(np.array(text_ids, dtype=builder.dtype)) builder.write_reminder() def prepare( source_path: Path = Path("data/RedPajama-Data-1T-Sample"), checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), destination_path: Path = Path("data/redpajama_sample"), sample: bool = True, match: str = "", ) -> None: """Prepare the "Red Pajama" dataset. 
We assume tokenizer has been trained.""" with open(checkpoint_dir / "lit_config.json") as fp: config = Config(**json.load(fp)) prepare_fn = prepare_sample if sample else prepare_full prepare_fn( source_path=source_path, checkpoint_dir=checkpoint_dir, destination_path=destination_path, chunk_size=(config.block_size + 1) * 1024, # block size + 1 for causal, 1024 blocks match=match, ) if __name__ == "__main__": from jsonargparse import CLI CLI(prepare) ================================================ FILE: scripts/prepare_slimpajama.py ================================================ import json import glob import os from pathlib import Path import sys from typing import List import numpy as np from tqdm import tqdm from multiprocessing import Process, cpu_count # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) import lit_gpt.packed_dataset as packed_dataset from lit_gpt import Tokenizer # Filename for SlimPajama slimpajama_sets = { "train": "train/chunk*/*", "validation": "validation/chunk*/*", "test": "test/chunk*/*", } def prepare_full( source_path: Path, tokenizer_path: Path, destination_path: Path, chunk_size: int, split: str="train", filenames_subset: List[str] = None, process_id: int = 0 ) -> None: import zstandard as zstd destination_path.mkdir(parents=True, exist_ok=True) tokenizer = Tokenizer(tokenizer_path) # Use the provided filenames_subset or default to all filenames filenames = filenames_subset if not filenames: raise RuntimeError( f"No files matching {slimpajama_sets[split]} found at {source_path}. \n" "Make sure you download the data..." 
) builder = packed_dataset.PackedDatasetBuilder( outdir=destination_path, prefix=f"{split}_slimpajama_{process_id}", # Use process_id to differentiate builders chunk_size=chunk_size, sep_token=tokenizer.bos_id, dtype="auto", vocab_size=tokenizer.vocab_size, ) for filepath in filenames: print(f"Processing {filepath}") with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: for row in tqdm(f): text = json.loads(row)["text"] if json.loads(row)["meta"]["redpajama_set_name"] == "RedPajamaGithub": continue # we don't want to include the github data text_ids = tokenizer.encode(text) builder.add_array(np.array(text_ids, dtype=builder.dtype)) # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details # builder.write_reminder() def prepare( source_path: Path = Path("data/RedPajama-Data-1T-Sample"), tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), destination_path: Path = Path("data/red_pajama_sample"), chunk_size: int = 2049 * 1024, split: str="train", percentage: float = 1.0, ) -> None: import time filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True) filenames = filenames[:int(len(filenames) * percentage)] num_processes = cpu_count() chunked_filenames = np.array_split(filenames, num_processes) processes = [] start_time = time.time() for i, subset in enumerate(chunked_filenames): p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) processes.append(p) p.start() for p in processes: p.join() end_time = time.time() elapsed_time = end_time - start_time print(f"Time taken: {elapsed_time:.2f} seconds") if __name__ == "__main__": from jsonargparse import CLI CLI(prepare) ================================================ FILE: scripts/prepare_starcoder.py ================================================ import json import glob import os from pathlib import 
Path
import sys
from typing import List
import numpy as np
from tqdm import tqdm
from multiprocessing import Process, cpu_count

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

import lit_gpt.packed_dataset as packed_dataset
from lit_gpt import Tokenizer

import pandas as pd


def prepare_full(
    source_path: Path,
    tokenizer_path: Path,
    destination_path: Path,
    chunk_size: int,
    split: str="train",
    filenames_subset: List[str] = None,
    process_id: int = 0
) -> None:
    import zstandard as zstd

    destination_path.mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(tokenizer_path)

    # Use the provided filenames_subset or default to all filenames
    filenames = filenames_subset

    if not filenames:
        raise RuntimeError(
            f"No files matching found at {source_path}. \n"
            "Make sure you download the data..."
        )

    builder = packed_dataset.PackedDatasetBuilder(
        outdir=destination_path,
        prefix=f"{split}_starcoder_{process_id}",  # Use process_id to differentiate builders
        chunk_size=chunk_size,
        sep_token=tokenizer.bos_id,
        dtype="auto",
        vocab_size=tokenizer.vocab_size,
    )

    for filepath in filenames:
        print(f"Processing {filepath}")
        try:
            contents = pd.read_parquet(filepath, engine='pyarrow')['content']
        except Exception:
            # skip unreadable files, but let KeyboardInterrupt/SystemExit propagate
            print(f"Error reading {filepath}!!")
            continue
        for text in contents:
            text_ids = tokenizer.encode(text)
            builder.add_array(np.array(text_ids, dtype=builder.dtype))

    # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details
    # builder.write_reminder()


def prepare(
    source_path: Path = Path("data/RedPajama-Data-1T-Sample"),
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    destination_path: Path = Path("data/red_pajama_sample"),
    chunk_size: int = 2049 * 1024,
    split: str="train",
    percentage: float = 1.0,
    filenames_subset: List[str] = None,
) -> None:
    import time
    assert split == "train"  # starcoder only has train data
    filenames
= glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True)
    # only retain subsets that follow the prefix in filenames_subset
    if filenames_subset:
        filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])]
    filenames = filenames[:int(len(filenames) * percentage)]
    num_processes = 64
    chunked_filenames = np.array_split(filenames, num_processes)

    processes = []
    start_time = time.time()

    for i, subset in enumerate(chunked_filenames):
        p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds")


if __name__ == "__main__":
    from jsonargparse import CLI
    CLI(prepare)

================================================
FILE: sft/finetune.py
================================================
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import copy
import json
import os
from os.path import exists, join, isdir
from dataclasses import dataclass, field
import sys
from typing import Optional, Dict, Sequence
import numpy as np
from tqdm import tqdm
import logging

import pandas as pd
import importlib
from packaging import version
from packaging.version import parse

import torch
import transformers
from torch.nn.utils.rnn import pad_sequence
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer
)
from datasets import load_dataset, Dataset
import evaluate

from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True

logger = logging.getLogger(__name__)

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        default="EleutherAI/pythia-12b"
    )
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
    )

@dataclass
class DataArguments:
    eval_dataset_size: int = field(
        default=1024, metadata={"help": "Size of validation dataset."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    source_max_len: int = field(
        default=1024,
        metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    target_max_len: int = field(
        default=256,
        metadata={"help": "Maximum target sequence length.
Sequences will be right padded (and possibly truncated)."},
    )
    dataset: str = field(
        default='alpaca',
        metadata={"help": "Which dataset to finetune on. See datamodule for options."}
    )
    dataset_format: Optional[str] = field(
        default=None,
        metadata={"help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]"}
    )

@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    train_on_source: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to train on the input in addition to the target text."}
    )
    report_to: str = field(
        default='none',
        metadata={"help": "To use wandb or something else for reporting."}
    )
    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
    optim: str = field(default='adamw_torch', metadata={"help": 'The optimizer to be used'})
    per_device_train_batch_size: int = field(default=16, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
    gradient_accumulation_steps: int = field(default=1, metadata={"help": 'How many gradients to accumulate before performing an optimizer step'})
    max_steps: int = field(default=10000, metadata={"help": 'How many optimizer update steps to take'})
    weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'})
    learning_rate: float = field(default=0.0002, metadata={"help": 'The learning rate'})
    remove_unused_columns: bool = field(default=False, metadata={"help": 'Remove unused columns. Needed to make this codebase work.'})
    max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
    gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing.
You want to use this.'})
    do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
    lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
    warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'})
    logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
    group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
    save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'})
    save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})

@dataclass
class GenerationArguments:
    # For more hyperparameters check:
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    # Length arguments
    max_new_tokens: Optional[int] = field(
        default=256,
        # note the trailing space: adjacent string literals are concatenated as-is
        metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops "
                          "if predict_with_generate is set."}
    )
    min_new_tokens : Optional[int] = field(
        default=None,
        metadata={"help": "Minimum number of new tokens to generate."}
    )

    # Generation strategy
    do_sample: Optional[bool] = field(default=False)
    num_beams: Optional[int] = field(default=1)
    num_beam_groups: Optional[int] = field(default=1)
    penalty_alpha: Optional[float] = field(default=None)
    use_cache: Optional[bool] = field(default=True)

    # Hyperparameters for logit manipulation
    temperature: Optional[float] = field(default=1.0)
    top_k: Optional[int] = field(default=50)
    top_p: Optional[float] = field(default=1.0)
    typical_p: Optional[float] = field(default=1.0)
    diversity_penalty:
Optional[float] = field(default=0.0)
    repetition_penalty: Optional[float] = field(default=1.0)
    length_penalty: Optional[float] = field(default=1.0)
    no_repeat_ngram_size: Optional[int] = field(default=0)


def get_accelerate_model(args, checkpoint_dir):

    device_map = "auto"

    # if we are in a distributed setting, we need to set the device map and max memory per device
    if os.environ.get('LOCAL_RANK') is not None:
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        device_map = {'': local_rank}

    print(f'loading base model {args.model_name_or_path}...')
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        device_map=device_map,
        trust_remote_code=args.trust_remote_code,
    )

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        padding_side="right",
        use_fast=True, # Fast tokenizer giving issues.
        trust_remote_code=args.trust_remote_code,
    )
    if tokenizer._pad_token is None:
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        if args.dataset == "OpenAssistant/oasst_top1_2023-08-25":
            chat_special_tokens = ["<|im_start|>", "<|im_end|>"]
            special_tokens_dict.update(additional_special_tokens=chat_special_tokens)

        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            model=model
        )

    return model, tokenizer


def print_trainable_parameters(args, model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || "
    )


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens = None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(non_special_tokens)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        output_embeddings_data = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
        output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
    print(f"Resized tokenizer and embedding to {len(tokenizer)} tokens.")


@dataclass
class DataCollatorForCausalLM(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    train_on_source: bool
    predict_with_generate: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Extract elements
        sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
        targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
        # Tokenize
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        tokenized_targets = self.tokenizer(
            targets,
            max_length=self.target_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        # Build the input and labels for causal LM
        input_ids = []
        labels = []
        for tokenized_source, tokenized_target in zip(
            tokenized_sources_with_prompt['input_ids'],
            tokenized_targets['input_ids']
        ):
            if not self.predict_with_generate:
                input_ids.append(torch.tensor(tokenized_source + tokenized_target))
                if not self.train_on_source:
                    labels.append(
                        torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
                    )
                else:
                    labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
            else:
                input_ids.append(torch.tensor(tokenized_source))
        #
Apply padding
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
        data_dict = {
            'input_ids': input_ids,
            'attention_mask':input_ids.ne(self.tokenizer.pad_token_id),
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict


def extract_unnatural_instructions_data(examples, extract_reformulations=False):
    out = {
        'input': [],
        'output': [],
    }
    for example_instances in examples['instances']:
        for instance in example_instances:
            out['input'].append(instance['instruction_with_input'])
            out['output'].append(instance['output'])
    if extract_reformulations:
        for example_reformulations in examples['reformulations']:
            if example_reformulations is not None:
                for instance in example_reformulations:
                    out['input'].append(instance['instruction_with_input'])
                    out['output'].append(instance['output'])
    return out


ALPACA_PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task.
"
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: "
    ),
}


def extract_alpaca_dataset(example):
    if example.get("input", "") != "":
        prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
    else:
        prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
    return {'input': prompt_format.format(**example)}


def local_dataset(dataset_name):
    if dataset_name.endswith('.json') or dataset_name.endswith('.jsonl'):
        full_dataset = Dataset.from_json(path_or_paths=dataset_name)
    elif dataset_name.endswith('.csv'):
        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name))
    elif dataset_name.endswith('.tsv'):
        full_dataset = Dataset.from_pandas(pd.read_csv(dataset_name, delimiter='\t'))
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_name}")

    split_dataset = full_dataset.train_test_split(test_size=0.1)
    return split_dataset


def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
    """
    Make dataset and collator for supervised fine-tuning.
    Datasets are expected to have the following columns: { `input`, `output` }

    Available datasets to be selected with `dataset` argument:
        - alpaca, 52002 examples
        - alpaca cleaned, 51942 examples
        - chip2 (OIG), 210289 examples
        - self-instruct, 82612 examples
        - hh-rlhf (Anthropic), 160800 examples
        - longform, 23.7k examples
        - oasst1 (OpenAssistant) primary message tree only, 9,846 examples

    Coming soon:
        - unnatural instructions core, 66010 examples
        - unnatural instructions full, 240670 examples
        - alpaca-gpt4, 52002 examples
        - unnatural-instructions-gpt4, 9000 examples
        - supernatural-instructions, 69624 examples (same as paper with 100 ex/task more can be used)
        - flan (FLAN v2), up to 20M examples available
        - vicuna

    """
    def load_data(dataset_name):
        if dataset_name == 'alpaca':
            return load_dataset("tatsu-lab/alpaca")
        elif dataset_name == 'alpaca-clean':
            return load_dataset("yahma/alpaca-cleaned")
        elif dataset_name == 'chip2':
            return load_dataset("laion/OIG", data_files='unified_chip2.jsonl')
        elif dataset_name == 'hh-rlhf':
            return load_dataset("Anthropic/hh-rlhf")
        elif dataset_name == 'longform':
            return load_dataset("akoksal/LongForm")
        elif dataset_name == 'oasst1':
            return load_dataset("timdettmers/openassistant-guanaco")
        elif dataset_name == "OpenAssistant/oasst_top1_2023-08-25":
            return load_dataset("OpenAssistant/oasst_top1_2023-08-25")
        elif dataset_name == 'vicuna':
            raise NotImplementedError("Vicuna data was not released.")
        else:
            if os.path.exists(dataset_name):
                try:
                    args.dataset_format = args.dataset_format if args.dataset_format else "input-output"
                    full_dataset = local_dataset(dataset_name)
                    return full_dataset
                except:
                    raise ValueError(f"Error loading dataset from {dataset_name}")
            else:
                raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.")

    def format_dataset(dataset, dataset_format):
        if (
            dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or
            (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean'])
        ):
            dataset =
dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
        elif dataset_format == 'chip2' or (dataset_format is None and args.dataset == 'chip2'):
            # chip2 rows look like "<human>: ...\n<bot>: ..."; split on the speaker tags
            dataset = dataset.map(lambda x: {
                'input': x['text'].split('\n<bot>: ')[0].replace('<human>: ', ''),
                'output': x['text'].split('\n<bot>: ')[1],
            })
        elif dataset_format == 'self-instruct' or (dataset_format is None and args.dataset == 'self-instruct'):
            for old, new in [["prompt", "input"], ["completion", "output"]]:
                dataset = dataset.rename_column(old, new)
        elif dataset_format == 'hh-rlhf' or (dataset_format is None and args.dataset == 'hh-rlhf'):
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': x['chosen']
            })
        elif dataset_format == 'oasst1' or (dataset_format is None and args.dataset == 'oasst1'):
            dataset = dataset.map(lambda x: {
                'input': '',
                'output': x['text'],
            })
        elif dataset_format == 'input-output':
            # leave as is
            pass
        # Remove unused columns.
        dataset = dataset.remove_columns(
            [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
        )
        return dataset

    # Load dataset.
    dataset = load_data(args.dataset)
    dataset = format_dataset(dataset, args.dataset_format)

    # Split train/eval, reduce size
    if args.do_eval or args.do_predict:
        if 'eval' in dataset:
            eval_dataset = dataset['eval']
        else:
            print('Splitting train dataset in train and validation according to `eval_dataset_size`')
            dataset = dataset["train"].train_test_split(
                test_size=args.eval_dataset_size, shuffle=True, seed=42
            )
            eval_dataset = dataset['test']
        if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples:
            eval_dataset = eval_dataset.select(range(args.max_eval_samples))
        if args.group_by_length:
            eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
    if args.do_train:
        train_dataset = dataset['train']
        if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
            train_dataset = train_dataset.select(range(args.max_train_samples))
        if args.group_by_length:
            train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

    data_collator = DataCollatorForCausalLM(
        tokenizer=tokenizer,
        source_max_len=args.source_max_len,
        target_max_len=args.target_max_len,
        train_on_source=args.train_on_source,
        predict_with_generate=args.predict_with_generate,
    )
    return dict(
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        predict_dataset=eval_dataset if args.do_predict else None,
        data_collator=data_collator
    )


def get_last_checkpoint(checkpoint_dir):
    if isdir(checkpoint_dir):
        is_completed = exists(join(checkpoint_dir, 'completed'))
        if is_completed: return None, True # already finished
        max_step = 0
        for filename in os.listdir(checkpoint_dir):
            if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'):
                max_step = max(max_step, int(filename.replace('checkpoint-', '')))
        if max_step == 0: return None, is_completed # training started, but no checkpoint
        checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
        print(f"Found a previous checkpoint at: {checkpoint_dir}")
        return checkpoint_dir, is_completed # checkpoint found!
    return None, False # first training


def train():
    hfparser = transformers.HfArgumentParser((
        ModelArguments, DataArguments, TrainingArguments, GenerationArguments
    ))
    model_args, data_args, training_args, generation_args, extra_args = \
        hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
    args = argparse.Namespace(
        **vars(model_args), **vars(data_args), **vars(training_args)
    )
    print(args)

    checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
    if completed_training:
        print('Detected that training was already completed!')

    model, tokenizer = get_accelerate_model(args, checkpoint_dir)

    model.config.use_cache = False
    print('loaded model')
    set_seed(args.seed)

    data_module = make_data_module(tokenizer=tokenizer, args=args)

    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k:v for k,v in data_module.items() if k != 'predict_dataset'},
    )

    # Verifying the datatypes and parameter counts before training.
    print_trainable_parameters(args, model)
    dtypes = {}
    for _, p in model.named_parameters():
        dtype = p.dtype
        if dtype not in dtypes: dtypes[dtype] = 0
        dtypes[dtype] += p.numel()
    total = 0
    for k, v in dtypes.items(): total+= v
    for k, v in dtypes.items():
        print(k, v, v/total)

    all_metrics = {"run_name": args.run_name}
    # Training
    if args.do_train:
        logger.info("*** Train ***")
        # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF.
        # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        all_metrics.update(metrics)
    # Evaluation
    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        all_metrics.update(metrics)
    # Prediction
    if args.do_predict:
        logger.info("*** Predict ***")
        prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'],metric_key_prefix="predict")
        prediction_metrics = prediction_output.metrics
        predictions = prediction_output.predictions
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        predictions = tokenizer.batch_decode(
            predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w') as fout:
            for i, example in enumerate(data_module['predict_dataset']):
                example['prediction_with_input'] = predictions[i].strip()
                example['prediction'] = predictions[i].replace(example['input'], '').strip()
                fout.write(json.dumps(example) + '\n')
        print(prediction_metrics)
        trainer.log_metrics("predict", prediction_metrics)
        trainer.save_metrics("predict", prediction_metrics)
        all_metrics.update(prediction_metrics)

    if (args.do_train or args.do_eval or args.do_predict):
        with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
            fout.write(json.dumps(all_metrics))


if __name__ == "__main__":
    train()

================================================
FILE: sft/script.sh
================================================
# We include a simple full-parameter finetuning & inference script here. Our V0.1 chat model is finetuned using this script.
# The FT dataset we use is openassistant-guanaco. For finetuning with less than 4GB RAM, we refer you to the QLoRA and bitsandbytes repos.
# We did not do extensive hyperparameter tuning, nor did we search for better-performing FT datasets.
# We hope the community can explore finetuning TinyLlama and come up with better chat models. I will include community-finetuned models in this repo.

# V0.1
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --multi_gpu --num_processes 4 --main_process_port 1234 finetune.py \
    --model_name_or_path PY007/TinyLlama-1.1B-intermediate-step-240k-503b \
    --output_dir ./output/503B_FT_lr1e-5_ep5 \
    --logging_steps 10 \
    --save_strategy epoch \
    --data_seed 42 \
    --save_total_limit 6 \
    --evaluation_strategy epoch \
    --eval_dataset_size 512 \
    --max_eval_samples 1000 \
    --per_device_eval_batch_size 1 \
    --max_new_tokens 32 \
    --dataloader_num_workers 3 \
    --group_by_length=False \
    --logging_strategy steps \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --warmup_ratio 0.05 \
    --lr_scheduler_type constant \
    --dataset oasst1 \
    --source_max_len 16 \
    --target_max_len 512 \
    --per_device_train_batch_size 4 \
    --max_steps 0 \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --adam_beta2 0.999 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --seed 0 \
    --trust_remote_code \
    --report_to wandb

# V0.2
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes 4 --main_process_port 1234 finetune.py \
    --model_name_or_path PY007/TinyLlama-1.1B-intermediate-step-480k-1T \
    --output_dir ./output/503B_FT_lr1e-5_ep5_top1_2023-08-25 \
    --logging_steps 10 \
    --save_strategy epoch \
    --data_seed 42 \
    --save_total_limit 6 \
    --evaluation_strategy epoch \
    --eval_dataset_size 512 \
    --max_eval_samples 1000 \
    --per_device_eval_batch_size 1 \
    --max_new_tokens 32 \
    --dataloader_num_workers 3 \
    --group_by_length=False \
    --logging_strategy steps \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --warmup_ratio 0.05 \
    --lr_scheduler_type constant \
    --dataset OpenAssistant/oasst_top1_2023-08-25 \
    --dataset_format oasst1 \
    --source_max_len 16 \
    --target_max_len 512 \
    --per_device_train_batch_size 4 \
    --max_steps 0 \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --adam_beta2 0.999 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --seed 0 \
    --trust_remote_code \
    --report_to wandb

================================================
FILE: sft/simple_inference.py
================================================
from transformers import AutoTokenizer
import transformers
import torch
model = "PY007/TinyLlama-1.1B-Chat-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

prompt = "Give me detailed info about Jeo Biden."
formatted_prompt = (
    f"### Human: {prompt} ### Assistant:"
)

sequences = pipeline(
    formatted_prompt,
    do_sample=True,
    top_k=50,
    top_p = 0.9,
    num_return_sequences=1,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")

================================================
FILE: sft/simple_inference2.py
================================================
from transformers import AutoTokenizer
import transformers
import torch
model = "PY007/TinyLlama-1.1B-Chat-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

prompt = "How to get in a good university?"
formatted_prompt = (
    f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
)

sequences = pipeline(
    formatted_prompt,
    do_sample=True,
    top_k=50,
    top_p = 0.9,
    num_return_sequences=1,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")

================================================
FILE: speculative_decoding/README.md
================================================
## Speculative Decoding

### HuggingFace "Assisted Generation"

| Large Model | Native Decoding | Assisted Decoding |
| ----------- | --------------- | ----------------- |
| guanaco-7b  | 69 seconds      | 38 seconds        |
| guanaco-13b | 84 seconds      | 45 seconds        |
| guanaco-33b | 109 seconds     | 62 seconds        |

We use PY007/TinyLlama-1.1B-Chat-v0.1 as the assistant model and vary the large model from guanaco-7B to 33B. Experiments are done on a single A40 GPU with the code in instruct_hf_assisted_decoding.py. TinyLlama is loaded in fp16 and the large models are loaded in 8-bit, both to make guanaco-33b fit in memory and to keep the setup consistent. The prompt used is "Give me detailed info about Jeo Biden.". max_new_tokens is set to 512.

You can read this [article](https://huggingface.co/blog/assisted-generation) for more information about HuggingFace's Assisted Generation.

Quote from HF: "due to INT8 quantization and the use of causal masking in assisted generation, the output of greedy decoding may differ in rare occasions."

#### TODO
- [ ] Thoroughly benchmark the average speedup on 52K Alpaca prompts.

### Llama.cpp Speculative Decoding
We have continue-pretrained a code TinyLlama from the 500B checkpoint on another 7B tokens of Python data [here](https://huggingface.co/PY007/TinyLlama-1.1B-python-v0.1).
The code for continue-pretraining can be found in pretrain/tinyllama_code.py

```
./speculative \
-m models/CodeLlama-7b-hf/ggml-model-f16.gguf \
-md models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf \
-p "# Quick-sort implementation in Python and sample usage:" \
-e -ngl 1 -t 4 -n 256 -s 20 --temp 0 --draft 8
```
This gives:
```
encoded 12 tokens in 0.247 seconds, speed: 48.638 t/s
decoded 265 tokens in 7.909 seconds, speed: 33.507 t/s

n_draft = 16
n_predict = 265
n_drafted = 317
n_accept = 195
accept = 61.514%

draft:

llama_print_timings: load time = 53.14 ms
llama_print_timings: sample time = 652.62 ms / 1 runs ( 652.62 ms per token, 1.53 tokens per second)
llama_print_timings: prompt eval time = 73.81 ms / 12 tokens ( 6.15 ms per token, 162.58 tokens per second)
llama_print_timings: eval time = 2247.77 ms / 378 runs ( 5.95 ms per token, 168.17 tokens per second)
llama_print_timings: total time = 8154.92 ms

target:

llama_print_timings: load time = 534.47 ms
llama_print_timings: sample time = 208.12 ms / 265 runs ( 0.79 ms per token, 1273.32 tokens per second)
llama_print_timings: prompt eval time = 4210.38 ms / 382 tokens ( 11.02 ms per token, 90.73 tokens per second)
llama_print_timings: eval time = 682.80 ms / 16 runs ( 42.68 ms per token, 23.43 tokens per second)
llama_print_timings: total time = 8214.11 ms
ggml_metal_free: deallocating
ggml_metal_free: deallocating
```

Even though the model is continue-pretrained exclusively on Python, it retains its ability in other languages, such as C:
```
./speculative \
-m models/CodeLlama-7b-hf/ggml-model-f16.gguf \
-md models/TinyLlama-1.1B-500B-python/ggml-model-q4_0.gguf \
-p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \
-e -ngl 1 -t 4 -n 256 -s 20 --temp 0 --draft 8
```
This gives:
```
encoded 25 tokens in 0.278 seconds, speed: 89.900 t/s
decoded 258 tokens in 6.432 seconds, speed: 40.112 t/s

n_draft = 28
n_predict = 258
n_drafted = 278
n_accept =
200
accept = 71.942%

draft:

llama_print_timings: load time = 932.54 ms
llama_print_timings: sample time = 583.50 ms / 1 runs ( 583.50 ms per token, 1.71 tokens per second)
llama_print_timings: prompt eval time = 81.50 ms / 25 tokens ( 3.26 ms per token, 306.73 tokens per second)
llama_print_timings: eval time = 1834.67 ms / 329 runs ( 5.58 ms per token, 179.32 tokens per second)
llama_print_timings: total time = 6710.30 ms

target:

llama_print_timings: load time = 18568.44 ms
llama_print_timings: sample time = 208.78 ms / 258 runs ( 0.81 ms per token, 1235.75 tokens per second)
llama_print_timings: prompt eval time = 3164.84 ms / 342 tokens ( 9.25 ms per token, 108.06 tokens per second)
llama_print_timings: eval time = 775.43 ms / 18 runs ( 43.08 ms per token, 23.21 tokens per second)
llama_print_timings: total time = 7650.67 ms
ggml_metal_free: deallocating
ggml_metal_free: deallocating
```

I have not tried 13B CodeLlama as the large model yet because my Mac does not have enough memory :).

================================================
FILE: speculative_decoding/instruct_hf_assisted_decoding.py
================================================
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

model_id = "huggyllama/llama-13b"
peft_model_id = "timdettmers/guanaco-13b"
assistant_checkpoint = "PY007/TinyLlama-1.1B-Chat-v0.1"

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "Give me detailed info about Jeo Biden."
formatted_prompt = f"### Human: {prompt}### Assistant:"
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
model.load_adapter(peft_model_id)
print("Large model loaded")
model.config.use_cache = True
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).half().to(device)
assistant_model.config.use_cache = True
print("Small model loaded")

print("###Native Decoding Starts...\n")
start = time.time()
outputs = model.generate(**inputs, assistant_model=None, max_new_tokens=512)
end = time.time()
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
print("Time: ", end - start)

print("###TinyLlama Assisted Decoding Starts...\n")
start = time.time()
outputs = model.generate(**inputs, assistant_model=assistant_model,max_new_tokens=512)
end = time.time()
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# print time in seconds
print("Time: ", end - start)