Repository: tdrussell/qlora-pipe Branch: main Commit: 90b8e7d62034 Files: 30 Total size: 232.7 KB Directory structure: gitextract_z9_xsnea/ ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── examples/ │ ├── alpaca.yml │ ├── capybara.yml │ ├── config.toml │ ├── config_dpo.toml │ ├── converted_dpo_dataset.yml │ ├── ds_config.json │ └── ultrafeedback.yml ├── kernels/ │ ├── cross_entropy_loss.py │ └── utils.py ├── models/ │ ├── layers.py │ ├── models.py │ └── pipeline_model.py ├── pyproject.toml ├── requirements.txt ├── tools/ │ ├── convert_dpo_dataset_to_chat_format.py │ ├── convert_ds_checkpoint_to_lora.py │ ├── merge_lora.py │ └── test_sampling.py ├── train.py └── utils/ ├── dataloader.py ├── dataset_utils.py ├── engine.py ├── hqq_utils.py ├── saver.py ├── unsloth_utils.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ ================================================ FILE: .gitmodules ================================================ [submodule "axolotl"] path = axolotl url = https://github.com/OpenAccess-AI-Collective/axolotl/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 tdrussell Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # qlora-pipe A pipeline parallel training script for LLMs. Refer to the changelog at the bottom for details on updates. ## About This is a training script I made so that I can fine-tune LLMs using my workstation with four 4090s. It is developed first and foremost for myself, with my own use cases in mind. It is scrappy and hacked together. It will likely *never* be a stable, well-supported training script like Axolotl. I am open sourcing the code in case it is useful to others, and also as a proof-of-concept that this kind of thing is possible. That being said, if something doesn't work right, or you would like it to support some feature, feel free to raise an issue and I'll try to look at it. ## Features - Pipeline parallel training, for efficiently training large models that cannot fit on one GPU - Supports QLoRA, LoRA, and full fine tuning - Quantize weights using either bitsandbytes or HQQ - Efficient model loading. Each process only loads the layers it needs, and quantizes and moves them to the GPU layer-by-layer. This means you can load a large model on a lot of GPUs even with limited system RAM. 
- Load any dataset that Axolotl can, using the same YAML config file format - Support for "raw text" training using either a structured list of documents in a JSON file, or a single txt file - Support for resuming training from a checkpoint, including the dataloader state, to easily allow training in a piecemeal fashion - Useful metrics logged to Tensorboard - Ability to specify a separate, fixed evaluation dataset - Train on multiple datasets simultaneously, with different sampling ratios per dataset - Models currently supported: Llama, Mistral, Mixtral, Qwen, Cohere (Command R), Phi-3 (mini and medium), Gemma 2, Gemma 3, Cohere2 (Command-A) ## Installing Clone the repository: ``` git clone --recurse-submodules https://github.com/tdrussell/qlora-pipe ``` If you already cloned it and forgot to do --recurse-submodules: ``` git submodule init git submodule update ``` Install Miniconda: https://docs.conda.io/en/latest/miniconda.html Create the environment: ``` conda create -n qlora-pipe python=3.12 conda activate qlora-pipe ``` Install the dependencies: ``` pip install -r requirements.txt ``` Install nvcc: ``` conda install nvidia::cuda-nvcc ``` ## Training __Start by reading through the config files in the examples directory__. There are lots of comments explaining what the various fields do. Then, make a copy and edit it however you like. At minimum, change the paths at the top to point to your model and desired output directory. Launch the training script: ``` NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" deepspeed --num_gpus=1 train.py --deepspeed --config examples/config.toml ``` RTX 4000 series GPUs need those 2 environment variables set. Other GPUs may not need them. ## Parallelism Deepspeed handles pipeline- and data-parallelism. Set the --num_gpus flag to however many GPUs you want to use. The config option `pipeline_stages` determines the level of model parallelism. Then, the data parallelism is automatically set so that all GPUs are used.
For example, with 8 GPUs and pipeline_stages=4, a single instance of the model is divided across 4 GPUs. Because there are 8 GPUs total, there are then 2 data-parallel instances. The option `gradient_accumulation_steps` in the Deepspeed JSON config file determines the amount of pipelining when using pipeline parallelism (pipeline_stages>1). The higher the value, the more the GPUs can overlap computation. For example, with gradient_accumulation_steps=1, there is a single batch that gets passed between the GPUs forward, then in reverse for the backward pass. Only 1 GPU is active at a time; the others are idle. As gradient_accumulation_steps increases, you start pipelining multiple forward/backward batches. At the beginning and end of the step, some GPUs will always be idle. So as gradient_accumulation_steps approaches infinity, you approach 100% theoretical utilization. In practice, a value of 8 or so already gives good average utilization with 2 GPUs. With more GPUs, you may want to go higher. ## Dataset configuration There are 3 options for specifying each dataset. Set the `dataset_type` field to one of: - axolotl - Loads the dataset using the Axolotl codebase. Set `dataset_path` to a YAML file that contains the same dataset configuration you would use in Axolotl. - doclist - Set `dataset_path` to a glob pattern matching one or more JSON or JSONL files. Each file should be a list of objects containing a 'text' key. For each file, all of the text is logically concatenated together, before being sliced into sequences. - textfile - Basically the same as doclist, except the `dataset_path` matches one or more txt files. Each text file is sliced into sequences. You can read dataset_utils.py for details on what each of these options is doing. You can have multiple datasets. Just add additional `[[datasets]]` entries. When using multiple datasets, there are different ways to combine them.
- `dataset_combination_mode` = 'concatenate' (the default) - Just concatenates the datasets. - `dataset_combination_mode` = 'interleave' - Uses the Huggingface Datasets library `interleave_datasets()` function. - Use the `dataset_interleave_stopping_strategy` setting to control when interleaving stops. - 'first_exhausted': stop when a dataset runs out of examples. - 'all_exhausted': stop when all datasets have run out of examples. This duplicates examples from smaller datasets. - When using the 'interleave' mode, datasets can have a relative `sample_weight`, which is a positive real number. This controls the relative proportion of the datasets when they are combined. - __IMPORTANT__: When using the 'interleave' mode, the manner in which the datasets are proportionally combined (i.e. sampled from) is affected by the `batch_size_tokens` setting: - If `batch_size_tokens` is unset, it means you are treating each example equally. Every batch has the same number of examples, even though they may be different lengths. So, when interleaving datasets, the rows are sampled according to the relative proportions given by the `sample_weight`. - If using `batch_size_tokens`, it means you are treating each token equally. Every batch varies the number of examples (because they might have different lengths) so that the token count is approximately constant. So, when interleaving datasets, the sampling ratios are adjusted so that the number of *tokens*, not rows, drawn from different datasets matches the `sample_weight`. This is implemented by scaling the sampling probabilities by the average length of the dataset. You can read the `combine_datasets()` function in dataset_utils.py if this is confusing. - __Which of these should I use?__ Probably set `batch_size_tokens`. I think this is the better way to think about things, and it matches what sample packing would do. 
For example, in Axolotl, it is recommended to use sample packing, which packs multiple examples into a single sequence so that the sequence length is constant. This means, in the loss function, each token is being treated with equal weight, not each original row in the dataset. Using `batch_size_tokens` in this training script mimics that behavior, and thus when interleaving datasets, it samples from them so that the token ratios adhere to the sample_weight specified. - __Example__: you have datasets A and B. B's average row length is twice that of A. A has a sample_weight of 2, B has a sample_weight of 1. - Not setting batch_size_tokens: when interleaving, you get 2 rows of A for every row of B. - Using batch_size_tokens: when interleaving, you get 4 rows of A for every row of B. This is because A's rows are on average half the length of B's rows, so you need twice as many as before so that the number of tokens in each matches the 2:1 ratio you specified with the sample_weight. ## On sample packing (or the lack thereof) Sample packing is not currently implemented. Instead, there is the option `batch_size_tokens`. If this field is set, the per-device batch size is ignored, and instead the batch size is adjusted dynamically to target a fixed number of tokens per batch, per device. This was easier to implement than sample packing, and does basically the same thing. It is also efficient: if I set batch_size_tokens to a modest 10000 and train a 7B model with the Alpaca dataset, all my 4090s hit their 350W power limit cap. Unless I'm missing something (definitely possible), it seems there is no need to support sample packing. ## Floating point precision There are different places you can specify the floating point dtype. `model_weight_dtype` controls the precision of the underlying model weights (for any weights not quantized), and `lora_weight_dtype` is for the lora weights. If you are using quantization, both bnb and hqq have options for the compute dtype as well. 
If you are using 16 bit dtypes, floating point roundoff error is a potential problem. For a good overview of the problem and solutions, see [Revisiting Bfloat16 Training](https://arxiv.org/pdf/2010.06192). TLDR: the main source of precision error when training with 16 bit weights is the weight update step: $(p = p + \Delta p * lr)$. When the update is very small compared to the parameter (which is often the case), there can be significant roundoff error, including the update being entirely dropped. Mixed precision training solves this by keeping a master copy of the weights in fp32, and running all optimizer steps in fp32. Kahan summation is another solution when training in full bf16, that keeps an extra bf16 buffer for each parameter to accumulate roundoff errors so that updates are never dropped. ### Okay but how should I configure things? - If unsure, set everything to bf16 and use the adamw_kahan optimizer type. Kahan summation is ESPECIALLY important for full fine tuning. Kahan summation requires an extra 2 bytes per trainable parameter compared to vanilla full bf16 training. - For LoRAs, another option is setting `lora_weight_dtype` to fp32, which also makes all optimizer states fp32. - For LoRAs only, with constant learning rate no lower than 5e-5 or so, I have seen full bf16 training with no Kahan summation mostly match fp32 or bf16 + Kahan. - (more experimental) You may try Deepspeed's bf16 mode, but I personally don't use this. I think this does something like mixed precision, where it wraps the optimizer to keep a master copy of the parameters in fp32, as well as doing gradient accumulation and all optimizer states in fp32. This will use much more memory than full bf16 + Kahan summation. ## Changelog ### 2025-03-12 - Change how weights are loaded to avoid Transformers internal method - Support Gemma 3 ### 2025-01-30 - Add pretokenized dataset option. - Update layers to work with the new way to pass position embeddings in HF Transformers. 
Please update Transformers to the latest version or you will get errors. ### 2024-12-29 - Add DPO training. The examples directory has a DPO example. ### 2024-07-02 - Add Gemma-2 support. ### 2024-06-20 - Add adamw_kahan optimizer type and make it the default in the example. ### 2024-05-19 **The old config file format will break.** Quantization is configured slightly differently now. Read examples/config.toml (at the time of this change it was named examples/config_7b.toml). It's only a few lines to change. - Change how quantization is configured. Quantization is now its own table in the TOML file. - Add HQQ quantization. ### 2024-04-28 - Add llama3 instruction formatting option when loading a ShareGPT formatted dataset using Axolotl. - Automatically add BOS token for Llama 3. - Add option for Unsloth activation checkpointing, which saves VRAM for a very small hit to performance. ### 2024-04-16 - Optimizer is now specified in the config.toml file. - Can use AdamW8bit optimizer. - MLP offloading works again. For MoE, can offload a specified number of experts. - Can have separate dtype for saved files. - Cohere model support (command-r) ### 2024-04-07 Make sure to update requirements! Axolotl does some dynamic importing, so things will break in a very hard-to-diagnose way if you don't have a new dependency that was added. - Removed the need for manually specifying cache directories for datasets. All dataset processing uses the Huggingface Datasets library and takes advantage of the automatic caching that it provides. - Added the ability to specify multiple datasets, with different ways to combine them. __This breaks the old config format for datasets.__ Refer to the example config for what it should look like now.
================================================ FILE: examples/alpaca.yml ================================================ datasets: - path: vicgalle/alpaca-gpt4 type: alpaca ================================================ FILE: examples/capybara.yml ================================================ chat_template: llama3 datasets: - path: ssmi153/Capybara-ShareGPT type: chat_template field_messages: conversations message_field_role: from message_field_content: value ================================================ FILE: examples/config.toml ================================================ # Paths model = '/data2/models/Meta-Llama-3.1-8B' output_dir = '/data/training_runs/llama3_8b_example' # Lora configuration # can use full_fine_tune=true and no quantization to train the whole model instead of a LoRA #full_fine_tune = true lora_rank = 64 lora_alpha = 64 lora_dropout = 0.05 # Train only specific modules. This is passed to the parameter of the same name in the LoraConfig. # If not set, adapt all linear modules. # Note, this ALSO affects full fine tuning. In that case, if this is set, only weights containing one # of these keys as substring will have requires_grad. If not set everything is trained. #target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'] # can specify layers to adapt with LoRA if you want #layers_to_transform = '16:31' # Training settings epochs = 2 lr_scheduler = 'cosine' # can also be 'constant' warmup_steps = 100 # Batch size of a single forward/backward pass for one GPU. micro_batch_size_per_gpu = 1 # Dynamic batch size, targeting this many tokens per batch, per device. # If set, completely ignores micro_batch_size_per_gpu. # Can be thought of as a replacement for sample packing. batch_size_tokens = 5000 # Number of pipeline parallel stages, must evenly divide the number of GPUs you launch the script with. The lower this is, the more GPU VRAM you need, but the faster the training will be. 
# A value of 1 means the model will be loaded on every GPU in full, each GPU running its own training batches. # A value of 2 means the model will be split into half, the halves loaded evenly across the (2, 4, 6, 8, ...) GPUs, where GPUs will work in pairs on each batch. (And so on.) pipeline_stages = 1 # Number of micro-batches sent through the pipeline for each training step. # If pipeline_stages > 1, a higher GAS means better GPU utilization due to smaller pipeline bubbles (where GPUs aren't overlapping computation). gradient_accumulation_steps = 4 # Grad norm clipping. gradient_clipping = 1.0 # might be useful if resuming from a checkpoint and you want to change the LR and force it to something #force_constant_lr = 5e-5 # hard clamp the magnitude of the LoRA weights #scale_weight_norms = 1.0 # for Mixtral, set the load balancing coefficient #load_balancing_loss_coef = 0.02 # Eval settings eval_steps = 100 # how often to run eval eval_before_first_step = true # do an eval before any training happens eval_after_last_step = false # do a final eval after the training completes # Performance settings logging_steps = 10 # how often to log in Tensorboard save_steps = 200 # how often to save the model checkpoint_every_n_minutes = 60 # how frequently to checkpoint training states (used for resuming training) # checkpoint_on_save = true # alternative to the above, this will cause a checkpoint save every time a regular save occurs (note: this setting takes precedence over checkpoint_every_n_minutes) # dtype to load the underlying model weights in model_weight_dtype = 'bfloat16' # dtype for the LoRA weights lora_weight_dtype = 'bfloat16' # Can have the saved weights be different dtype. Don't need to set this. Could be useful for # training in float32 but saving with float16. 
#save_dtype = 'bfloat16' # Keep this number of stepXXXX (model saves) and global_stepXXX (checkpoint saves) and delete the rest # (this only applies to the current training session, and resumed training sessions will not touch # old saves) keep_states = 3 # sort examples by length before dividing them into batches # this makes all examples in a batch approximately the same length, to minimize padding # the batches are still shuffled after that # you should probably always have this set to true group_by_length = true # This can also be 'unsloth' to offload hidden states to CPU, saving potentially a lot of VRAM # for a minor performance hit. # Example: 4x4090, PCIE 3.0 16x, pipeline_stages=4, training QLoRA on Llama 3 70B with 4096 sequence length. # true: 75s step time, 19.7G peak per-GPU VRAM usage. # 'unsloth': 78s step time, 16.2G peak per-GPU VRAM usage. activation_checkpointing = true # Keep MLP weights on system RAM until they are needed. Can save a ton of VRAM with a # moderate hit to performance. If using an MoE model, this can also be an integer, in # which case only that many experts are offloaded (tradeoff between VRAM and speed). #offload_mlp_to_cpu = true # Resume a prior run # if true, we attempt to resume training from the most recent directory inside output_dir (the directory names are timestamps) # so, to resume, just run the exact same command but set this to true first resume_from_checkpoint = false # Loading the optimizer states seems to cause some kind of unavoidable VRAM memory leak. # It's very small, only about 0.2 GB in cases I've seen. But if you are very close to the # limit, it can cause resuming from checkpoint to OOM. As a last resort, you can uncomment # this to not load the optimizer states and hopefully the resumption won't OOM. #load_optimizer_states = false # Dataset configuration # How to combine multiple datasets if you have more than one. # Can be 'concatenate' or 'interleave'. Will be 'concatenate' if not set. 
dataset_combination_mode = 'interleave' # When to stop interleaving datasets when using mode 'interleave'. Either 'first_exhausted' or 'all_exhausted'. # Default if not set: 'first_exhausted' dataset_interleave_stopping_strategy = 'all_exhausted' # Can set this lower than training, so we don't drop as many examples when trying to make equal-sized batches. # Default if not set: same as training GAS. eval_gradient_accumulation_steps = 1 # bitsandbytes 4 bit quantization. The parameters here become arguments to Transformers BitsAndBytesConfig. [quantization.bnb] load_in_4bit = true bnb_4bit_use_double_quant = false bnb_4bit_compute_dtype = 'bfloat16' # HQQ quantization. The parameters here become arguments to CustomHQQConfig. # [quantization.hqq] # nbits = 4 # group_size = 64 # compute_dtype = 'bfloat16' # (Optional) You can override the quant params for certain modules. This does substring matching, e.g. if 'gate_proj' # is a substring of the full module name, anything specified overwrites the defaults in [quantization.hqq]. # [quantization.hqq.dynamic_config] # gate_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true} # up_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true} # down_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true} [optimizer] # options: adamw_kahan, AdamW, AdamW8bit type = 'adamw_kahan' lr = 5e-5 beta1 = 0.9 beta2 = 0.99 weight_decay = 0.1 [[datasets]] # Arbitrary name, used only for separately logging eval metrics. Will be dataset0, dataset1, etc if not set. name = 'alpaca' dataset_type = 'axolotl' dataset_path = 'examples/alpaca.yml' sequence_len = 2048 eval_size = 0.02 # Relative sampling weight, when using combination mode 'interleave'. Will be 1 if not set. 
sample_weight = 1 [[datasets]] name = 'capybara' dataset_type = 'axolotl' dataset_path = 'examples/capybara.yml' sequence_len = 2048 eval_size = 0.02 sample_weight = 1.5 # In addition to using eval_size which splits off some of the dataset, we can have completely separate datasets for eval. # This can be useful if you're training on raw text data, so that the eval set remains completely fixed, even if # you change training sequence_len, etc. # This is just an example, typically you wouldn't have this overlap a training dataset. # [[eval_datasets]] # name = 'capybara' # dataset_type = 'axolotl' # dataset_path = 'examples/capybara.yml' # sequence_len = 2048 ================================================ FILE: examples/config_dpo.toml ================================================ # Paths model = '/data2/models/Meta-Llama-3.1-8B-Instruct' output_dir = '/data/training_runs/llama3_8b_dpo_example' lora_rank = 64 lora_alpha = 64 lora_dropout = 0.05 epochs = 2 lr_scheduler = 'constant' warmup_steps = 100 batch_size_tokens = 5000 micro_batch_size_per_gpu = 1 pipeline_stages = 1 gradient_accumulation_steps = 4 gradient_clipping = 1.0 eval_steps = 100 eval_before_first_step = true eval_after_last_step = false logging_steps = 10 save_steps = 200 checkpoint_every_n_minutes = 60 model_weight_dtype = 'bfloat16' lora_weight_dtype = 'bfloat16' group_by_length = true activation_checkpointing = true eval_gradient_accumulation_steps = 1 [rl] method = 'dpo' dpo_beta = 0.02 # [quantization.bnb] # load_in_4bit = true # bnb_4bit_use_double_quant = false # bnb_4bit_compute_dtype = 'bfloat16' [optimizer] type = 'adamw_kahan' lr = 5e-5 beta1 = 0.9 beta2 = 0.99 weight_decay = 0.1 [[datasets]] name = 'ultrafeedback' dataset_type = 'axolotl' dataset_path = 'examples/ultrafeedback.yml' sequence_len = 4096 eval_size = 0.01 ================================================ FILE: examples/converted_dpo_dataset.yml ================================================ # Some DPO datasets are not in 
conversation format. # They need to be in conversation format to load them in this script. You can convert them: # python tools/convert_dpo_dataset_to_chat_format.py unalignment/toxic-dpo-v0.2 ~/data/toxic-dpo-v0.2-converted chat_template: llama3 datasets: - path: json data_files: - /home/anon/data/toxic-dpo-v0.2-converted/train.json split: train type: orpo.chat_template ================================================ FILE: examples/ds_config.json ================================================ { "train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 1 } ================================================ FILE: examples/ultrafeedback.yml ================================================ # This dataset is already in a format that can be directly loaded by the orpo.chat_template type. chat_template: llama3 datasets: - path: argilla/ultrafeedback-binarized-preferences-cleaned type: orpo.chat_template ================================================ FILE: kernels/cross_entropy_loss.py ================================================ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch
import triton
import triton.language as tl

from .utils import MAX_FUSED_SIZE, calculate_settings, device_warp_size


# Forward kernel used when the whole vocabulary fits in one Triton block (n_chunks == 1).
# The launch grid is (n_rows,): one program instance per row of `logits`.
@triton.heuristics(
    {
        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],
    }
)
@triton.jit
def _cross_entropy_forward(
    logits_ptr,
    logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    Per-row cross-entropy forward: stores the row's loss and its logsumexp
    (the logsumexp is reused by the backward kernel).

    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    logsumexp is also stable
    Take    y =         log[sum(exp(x))]
       exp(y) =             sum(exp(x))
       exp(y) =             sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
       exp(y) =      exp(c)*sum(exp(x - c))
           y  = log(exp(c)*sum(exp(x - c)))
           y  = c + log[sum(exp(x - c))]
    This means we can set c = max(x) to make sure
    exp(x - c) always is exp(x - max(x)).
    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.

    Rows whose label is -100 get loss 0.0.
    """
    row_idx = tl.program_id(0)
    # Advance all pointers to this program's row. The int64 cast presumably guards
    # against int32 overflow of row_idx * stride for large logits tensors.
    logits_ptr += row_idx * logits_row_stride.to(tl.int64)
    loss_ptr += row_idx
    logsumexp_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE is a power of two >= VOCAB_SIZE; lanes past the vocab are masked.
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    # Masked lanes load -inf so exp(-inf - c) == 0 and they drop out of the sum.
    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)
    if DO_LOGIT_SCALING:
        # Logit scaling: s * x
        logits = LOGIT_SCALE * logits
    pass
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if label_idx != -100:
        x = tl.load(logits_ptr + label_idx).to(tl.float32)
        if DO_LOGIT_SCALING:
            # Logit scaling: s * x
            x = LOGIT_SCALE * x
        pass
        loss = logsumexp - x
    else:
        loss = 0.0
    tl.store(logsumexp_ptr, logsumexp)
    tl.store(loss_ptr, loss)


pass


# Forward kernel for vocabularies larger than MAX_FUSED_SIZE. The launch grid is
# (n_rows, n_chunks): each program computes the logsumexp of one BLOCK_SIZE-wide
# chunk of one row; the host reduces the per-chunk values afterwards.
@triton.heuristics(
    {
        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],
    }
)
@triton.jit
def _chunked_cross_entropy_forward(
    logits_ptr,
    logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    N_CHUNKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    256K vocab divided in 4 chunks

    |-65536-| |-65536-| |-65536-| |-65536-|
    |-------| |-------| |-------| |-------|
    |-------| |-------| |-------| |-------|

    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    Notice we can do logsumexp for each chunk and then
    logsumexp[chunk_sum(logsumexp)] == logsumexp

    chunk_sum = log[chunk_sum(logsumexp)]
              = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
              = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
              = log[sum(exp(a)) + ... + sum(exp(z))]
              = logsumexp(x)

    This means we can perform a logsumexp for each chunk, then do a
    final logsumexp reduction!

    Ie do: logsumexp(chunked_logsumexp) - x

    Only chunk 0 writes loss_ptr, and it writes only the -x term; the host
    adds the reduced logsumexp and zeroes ignored (-100) rows afterwards.
    """
    row_idx = tl.program_id(0)
    chunk_idx = tl.program_id(1)
    logits_ptr += row_idx * logits_row_stride.to(tl.int64)
    loss_ptr += row_idx
    # logsumexp is laid out as [n_rows, N_CHUNKS]; point at this (row, chunk) cell.
    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
    labels_ptr += row_idx

    col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)
    if DO_LOGIT_SCALING:
        # Logit scaling: s * x
        logits = LOGIT_SCALE * logits
    pass
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if chunk_idx == 0:
        # logsumexp(chunked_logsumexp) - x
        # Do the -x separately
        if label_idx != -100:
            x = tl.load(logits_ptr + label_idx).to(tl.float32)
            if DO_LOGIT_SCALING:
                # Logit scaling: s * x
                x = LOGIT_SCALE * x
            pass
            loss = -1.0 * x
        else:
            loss = 0.0
        tl.store(loss_ptr, loss)
    pass
    tl.store(logsumexp_ptr, logsumexp)


pass


# Backward kernel. The launch grid is (n_rows, n_blocks). It writes dL/dlogits directly
# into the logits buffer (tl.store to logits_ptr), i.e. the gradient overwrites the
# saved logits in place.
@triton.heuristics(
    {
        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],
    }
)
@triton.jit
def _cross_entropy_backward(
    logits_ptr,
    logits_row_stride,
    dloss_ptr,
    dloss_row_stride,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)

    From https://en.wikipedia.org/wiki/LogSumExp
    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)

    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)

    If y == 0: dC/dx = 0
    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
    If y == 1 and x != label: dC/dx     = exp[x - logsumexp]
    """
    row_idx = tl.program_id(0)
    block_idx = tl.program_id(1)

    logits_ptr += row_idx * logits_row_stride.to(tl.int64)
    dloss_ptr += row_idx * dloss_row_stride
    col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE
    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    # Ignored rows (-100) get a zero upstream gradient, so their whole row of
    # gradients is written as zeros.
    if label_idx != -100:
        dloss = tl.load(dloss_ptr)
    else:
        dloss = 0.0

    x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)
    if DO_LOGIT_SCALING:
        # d/dx [s * x] = s
        x = LOGIT_SCALE * x
    pass
    logsumexp = tl.load(logsumexp_ptr + row_idx)
    # softmax(x) recovered from the forward pass's logsumexp.
    y = tl.exp(x - logsumexp)
    y = tl.where(
        col_offsets == label_idx,
        y - 1.0,  # exp(x - logsumexp) - 1
        y,  # exp(x - logsumexp)
    )

    # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
    if DO_LOGIT_SCALING:
        # d/dx [s * x] = s
        y = LOGIT_SCALE * y
    pass
    tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)


pass


class Fast_CrossEntropyLoss(torch.autograd.Function):
    """Fused Triton cross-entropy returning one unreduced float32 loss per row.

    forward(logits, labels, logit_scale=1.0):
        logits: 2-D tensor unpacked as (n_rows, vocab_size).
        labels: one label index per row; -100 marks an ignored row (loss 0).
        logit_scale: when != 1.0 the logits are multiplied by it inside the kernels.

    backward writes the gradient into the saved `logits` tensor in place and
    returns that tensor, so the caller's logits are overwritten after backward.
    Output buffers are allocated on device='cuda'.
    """

    @staticmethod
    def forward(ctx, logits, labels, logit_scale=1.0):
        n_rows, vocab_size = logits.shape

        div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
        n_chunks = div + (mod != 0)
        losses = torch.empty(n_rows, dtype=torch.float32, device='cuda')

        if n_chunks == 1:
            # For small vocabs <= 65536 like Llama, Mistral
            BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
            logsumexp = torch.empty(n_rows, dtype=torch.float32, device='cuda')

            _cross_entropy_forward[(n_rows,)](
                logits,
                logits.stride(0),
                losses,
                logsumexp,
                labels,
                VOCAB_SIZE=vocab_size,
                BLOCK_SIZE=BLOCK_SIZE,
                DO_LOGIT_SCALING=(logit_scale != 1.0),
                LOGIT_SCALE=logit_scale,
                num_warps=num_warps,
            )
        else:
            # For large vocabs > 65536 like Gemma 256K
            logsumexp = torch.empty(
                (
                    n_rows,
                    n_chunks,
                ),
                dtype=torch.float32,
                device='cuda',
            )

            _chunked_cross_entropy_forward[
                (
                    n_rows,
                    n_chunks,
                )
            ](
                logits,
                logits.stride(0),
                losses,
                logsumexp,
                labels,
                VOCAB_SIZE=vocab_size,
                N_CHUNKS=n_chunks,
                BLOCK_SIZE=MAX_FUSED_SIZE,
                DO_LOGIT_SCALING=(logit_scale != 1.0),
                LOGIT_SCALE=logit_scale,
                num_warps=32 if device_warp_size() < 64 else 16,
            )
            # logsumexp(chunked_logsumexp) - x
            # Do the -x separately
            logsumexp = torch.logsumexp(logsumexp, dim=1)  # Row sum
            losses += logsumexp
            losses.masked_fill_(labels == -100, 0)  # Don't forget to mask padding out!
        pass

        # The saved per-row logsumexp is 1-D in both branches, as the backward kernel expects.
        ctx.save_for_backward(logits, logsumexp, labels)
        ctx.logit_scale = logit_scale
        return losses

    pass

    @staticmethod
    def backward(ctx, dlosses):
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, vocab_size = logits.shape

        BLOCK_SIZE = 4096
        div, mod = divmod(vocab_size, BLOCK_SIZE)
        n_blocks = div + (mod != 0)

        _cross_entropy_backward[
            (
                n_rows,
                n_blocks,
            )
        ](
            logits,
            logits.stride(0),
            dlosses,
            dlosses.stride(0),
            logsumexp,
            labels,
            VOCAB_SIZE=vocab_size,
            BLOCK_SIZE=BLOCK_SIZE,
            DO_LOGIT_SCALING=(ctx.logit_scale != 1.0),
            LOGIT_SCALE=ctx.logit_scale,
            num_warps=8,
        )
        # `logits` now holds dL/dlogits (overwritten in place by the kernel).
        # No gradients for labels or logit_scale.
        return (
            logits,
            None,
            None,
        )

    pass


pass


def fast_cross_entropy_loss(logits, labels, logit_scale=1.0):
    """
    Mean cross-entropy over all non-ignored (label != -100) positions.

    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
        logit_scale: multiplier applied to the logits inside the kernels.
    Returns:
        losses: 0-dim tensor (sum of per-token losses / number of non-ignored tokens).

    NOTE(review): if every label is -100, n_items is 0 and the result is 0/0 = NaN.
    """
    batch, seq_len, d = logits.shape
    assert labels.shape == (batch, seq_len)

    loss = Fast_CrossEntropyLoss.apply(
        logits.view(batch * seq_len, d),
        labels.view(-1),
        logit_scale,
    )
    n_items = torch.count_nonzero(labels != -100)
    return loss.sum() / n_items


pass

================================================
FILE: kernels/utils.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes import bitsandbytes as bnb import torch import triton MAX_FUSED_SIZE = 65536 next_power_of_2 = triton.next_power_of_2 def device_warp_size(): if torch.cuda.is_available() and torch.version.hip and torch.cuda.get_device_capability(0)[0] < 10: return 64 else: return 32 def calculate_settings(n): BLOCK_SIZE = next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError( f'Cannot launch Triton kernel since n = {n} exceeds the maximum CUDA blocksize = {MAX_FUSED_SIZE}.' ) warp_scalar = 32 / float(device_warp_size()) num_warps = 4 if BLOCK_SIZE >= 32768: num_warps = 32 * warp_scalar elif BLOCK_SIZE >= 8192: num_warps = 16 * warp_scalar elif BLOCK_SIZE >= 2048: num_warps = 8 * warp_scalar return BLOCK_SIZE, int(num_warps) pass get_ptr = bnb.functional.get_ptr cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 def QUANT_STATE(W): return getattr(W, 'quant_state', None) pass def get_lora_parameters(proj): # For DPO or disabled adapters base_layer = proj.base_layer if hasattr(proj, 'base_layer') else proj W = base_layer.weight if not hasattr(proj, 'disable_adapters') or proj.disable_adapters or proj.merged: return W, QUANT_STATE(W), None, None, None pass active_adapter = proj.active_adapters[0] if hasattr(proj, 'active_adapters') else proj.active_adapter A = proj.lora_A[active_adapter].weight B = proj.lora_B[active_adapter].weight s = proj.scaling[active_adapter] return W, QUANT_STATE(W), A, B, s pass def fast_dequantize(W, quant_state=None, out=None): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class # 
https://github.com/TimDettmers/bitsandbytes/pull/763/files absmax = quant_state.absmax shape = quant_state.shape dtype = quant_state.dtype blocksize = quant_state.blocksize offset = quant_state.offset state2 = quant_state.state2 absmax2 = state2.absmax code2 = state2.code blocksize2 = state2.blocksize else: # Old quant_state as a list of lists absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state offset, state2 = compressed_stats absmax2, code2, blocksize2, _, _, _, _ = state2 pass # Create weight matrix if out is None: out = torch.empty(shape, dtype=dtype, device='cuda') else: assert out.shape == shape assert out.dtype == dtype # NF4 dequantization of statistics n_elements_absmax = absmax.numel() out_absmax = torch.empty(n_elements_absmax, dtype=torch.float32, device='cuda') # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), ctypes.c_int(blocksize), ctypes.c_int(out.numel())) # Careful returning transposed data is_transposed = True if W.shape[0] == 1 else False return out.t() if is_transposed else out pass def fast_gemv(X, W, quant_state, out=None): if quant_state is None: return torch.matmul(X, W, out=out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 _, q_len, hd = X.shape # assert(q_len == 1) if type(quant_state) is not list: # https://github.com/TimDettmers/bitsandbytes/pull/763/files absmax = quant_state.absmax shape = quant_state.shape dtype = quant_state.dtype blocksize = quant_state.blocksize stats = quant_state.code offset = quant_state.offset state2 = quant_state.state2 absmax2 = state2.absmax code2 = state2.code 
blocksize2 = state2.blocksize else: absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state offset, state2 = compressed_stats absmax2, code2, blocksize2, _, _, _, _ = state2 pass # assert(dtype == X.dtype) bout = shape[0] if out is None: out = torch.empty( ( 1, 1, bout, ), dtype=dtype, device='cuda', ) # else: # assert(out.shape == (1, 1, bout,)) # pass n = 1 m = shape[0] k = shape[1] lda = shape[0] ldc = shape[0] ldb = (hd + 1) // 2 m = ctypes.c_int32(m) n = ctypes.c_int32(n) k = ctypes.c_int32(k) lda = ctypes.c_int32(lda) ldb = ctypes.c_int32(ldb) ldc = ctypes.c_int32(ldc) df = torch.empty(absmax.shape, dtype=torch.float32, device='cuda') cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), ) df += offset absmax = df fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else cgemm_4bit_inference_naive_bf16 blocksize = ctypes.c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize) return out pass def fast_linear_forward(proj, X, temp_lora=None, out=None): W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj) bsz, q_len, in_dim = X.shape if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) if W_quant is None: out = torch.matmul(X, W.t(), out=out) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out=out) else: W = fast_dequantize(W.t(), W_quant) out = torch.matmul(X, W, out=out) pass # Add in LoRA weights if lora_A is not None: out_dim = out.shape[2] dtype = X.dtype if not hasattr(lora_A, '_fast_lora'): lora_A._fast_lora = lora_A.to(dtype) lora_B._fast_lora = lora_B.to(dtype) pass if bsz == 1: out = out.view(out_dim) temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out=temp_lora) out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S) else: out = out.view(bsz, out_dim) temp_lora = torch.mm(X.view(bsz, in_dim), 
lora_A._fast_lora.t(), out=temp_lora) out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S) pass out = out.view(bsz, 1, out_dim) pass return out pass def matmul_lora(X, W, W_quant, A, B, s, out=None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant) if X.dim() == 3: batch, seq_len, d = X.shape X = X.view(-1, X.shape[-1]) reshape = True else: reshape = False pass out = torch.matmul(X, W, out=out) if W_quant is not None: del W if A is not None: # LoRA is enabled A, B = A.t(), B.t() out += (X @ A.to(dtype)) @ (s * B.to(dtype)) pass return out.view(batch, seq_len, -1) if reshape else out pass ================================================ FILE: models/layers.py ================================================ import math import torch import torch.nn.functional as F import transformers from deepspeed.runtime.pipe import module as ds_pipe_module from torch import nn from kernels.cross_entropy_loss import Fast_CrossEntropyLoss def move_data_to_device(module, device): non_blocking = (device != 'cpu') # handle lora if hasattr(module, 'base_layer'): module = module.base_layer # handle HQQ if hasattr(module, 'W_q'): orig_data = module.W_q.data module.W_q.data = orig_data.to(device, non_blocking=non_blocking) else: orig_data = module.weight.data module.weight.data = orig_data.to(device, non_blocking=non_blocking) return orig_data def set_data(module, data): # handle lora if hasattr(module, 'base_layer'): module = module.base_layer # handle HQQ if hasattr(module, 'W_q'): module.W_q.data = data else: module.weight.data = data def move_experts_to_device(experts, device, num_experts_to_offload): orig_data = [] for i in range(num_experts_to_offload): orig_w1 = move_data_to_device(experts[i].w1, device) orig_w2 = move_data_to_device(experts[i].w2, device) orig_w3 = move_data_to_device(experts[i].w3, device) orig_data.append((orig_w1, orig_w2, orig_w3)) return orig_data def set_experts_data(experts, orig_data): for i, (orig_w1, orig_w2, orig_w3) in enumerate(orig_data): 
set_data(experts[i].w1, orig_w1) set_data(experts[i].w2, orig_w2) set_data(experts[i].w3, orig_w3) def entropy_fn(logits): result = [] # There is a very wide range of chuck sizes that cause no increase in memory reported by # nvidia-smi (Torch re-using blocks of memory?). If you try to compute it as one tensor, # memory usage is huge. Chuck size of 128 seems good enough for now. for logits_chuck in torch.split(logits, 128): result.append(torch.distributions.Categorical(logits=logits_chuck).entropy()) return torch.cat(result).float() def top_k_accuracy(logits, labels, k_list, ignore_index=-100): keep = labels != ignore_index labels = labels[keep].view(-1, 1) max_k = max(k_list) _, top_k_predictions = torch.topk(logits, max_k, dim=-1, sorted=True) top_k_predictions = top_k_predictions[keep] accuracies = [] for k in k_list: accuracies.append(torch.any(top_k_predictions[:, :k] == labels, dim=-1).to(torch.float32).mean()) return accuracies class LayerSpec(ds_pipe_module.LayerSpec): def __init__(self, typename, *module_args, **module_kwargs): super().__init__(typename, *module_args, **module_kwargs) def build(self): self.module_kwargs.pop('_estimated_size', None) return self.typename(*self.module_args, **self.module_kwargs) @property def estimated_size(self): return self.module_kwargs.get('_estimated_size', 1) # TODO: consider using Liger-Kernel fused loss implementations. The inputs are already set up to support this. # Would save VRAM, but some metrics could no longer be computed (e.g. entropy, accuracies). class OutputLayer(nn.Module): def __init__( self, pipeline_model, loader_util, lm_head, logit_scale=1.0, loss_type='cross_entropy_loss', focal_loss_gamma=0, tie_weights=None, logit_softcapping=None, ): super().__init__() # Assign list to prevent registering the nn.Module self.pipeline_model = [pipeline_model] # Unlike the other wrapper classes, this is called lm_head and not orig. 
Because this is directly a # nn.Linear layer, it needs to keep the same attribute name so quantization knows not to quantize it. self.lm_head = lm_head self.logit_scale = logit_scale self.loss_type = loss_type.lower() self.focal_loss_gamma = focal_loss_gamma if tie_weights: self.lm_head.weight.original_name = tie_weights self.logit_softcapping = logit_softcapping loader_util.load_state_dict_into_module(self) if self.loss_type == 'cross_entropy_loss' and self.focal_loss_gamma != 0: raise ValueError("focal_loss_gamma can't be used with 'cross_entropy_loss' function") def forward(self, inputs): hidden_states, labels = inputs if self.pipeline_model[0].sampling_mode: # When sampling only compute the last logits. hidden_states = hidden_states[:, -1:, :] labels = labels.to(hidden_states.device) if self.logit_scale == 1.0: logits = self.lm_head(hidden_states) else: logits = self.lm_head(self.logit_scale * hidden_states) if self.logit_softcapping is not None and self.logit_softcapping > 0: logits = logits / self.logit_softcapping logits = torch.tanh(logits) logits = logits * self.logit_softcapping if self.pipeline_model[0].sampling_mode: return logits extra_ignored_labels = torch.full((labels.shape[0], 1), -100, device=logits.device) labels = torch.hstack((labels[..., 1:], extra_ignored_labels)) # Flatten the tokens vocab_size = logits.size(-1) flat_logits = logits.view(-1, vocab_size) flat_labels = labels.view(-1) flat_loss_mask = flat_labels >= 0 cross_entropy_loss = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels) loss = None if self.loss_type == 'cross_entropy_loss': cross_entropy_loss = cross_entropy_loss[flat_loss_mask] loss_unreduced = cross_entropy_loss elif self.loss_type == 'focal_loss': cross_entropy_loss = cross_entropy_loss[flat_loss_mask] # See https://arxiv.org/abs/1708.02002 (Section 3) p = torch.exp(-cross_entropy_loss) loss_unreduced = (1 - p) ** self.focal_loss_gamma * cross_entropy_loss elif self.loss_type == 'focal_loss_star': cross_entropy_loss = 
cross_entropy_loss[flat_loss_mask] # See https://arxiv.org/abs/1708.02002 (Appendix A/B) # NOTE: The use of Beta makes no sense for the multinomial case as it's invariant to translation loss_unreduced = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels, self.focal_loss_gamma) loss_unreduced = loss_unreduced[flat_loss_mask] loss_unreduced = loss_unreduced / self.focal_loss_gamma elif self.loss_type == 'inverse_focal_loss': cross_entropy_loss = cross_entropy_loss[flat_loss_mask] # See "Rethinking Calibration of Deep Neural Networks: Do Not Be Afraid of Overconfidence" (Section 5.2) # NOTE: The alternative of p^gamma (instead of (1+p)^gamma) might be useful for gradient ascent... p = torch.exp(-cross_entropy_loss) loss_unreduced = (1 + p) ** self.focal_loss_gamma * cross_entropy_loss elif self.loss_type == 'exponentiated_cross_entropy_loss': cross_entropy_loss = cross_entropy_loss[flat_loss_mask] # See "Gradient as a Foundation for Building a Loss Function" (Section III.B) # NOTE: This is a generalisation of their "Quadratic Cross-Entropy" loss (QCE: gamma=2, CE: gamma=1, etc). 
loss_unreduced = cross_entropy_loss**self.focal_loss_gamma / self.focal_loss_gamma elif self.loss_type == 'dpo': rl_config = self.pipeline_model[0].train_config['rl'] cross_entropy_loss = cross_entropy_loss.view_as(labels) # unflatten loss_mask = labels >= 0 logps = -(cross_entropy_loss * loss_mask).sum(-1) half = cross_entropy_loss.size(0) // 2 chosen_logps = logps[:half] rejected_logps = logps[half:] if self.pipeline_model[0].dpo_reference_mode: self.reference_chosen_logps = chosen_logps.detach() self.reference_rejected_logps = rejected_logps.detach() return torch.tensor(0.0, device=logits.device) # log the language modeling loss metrics on the chosen completion cross_entropy_loss = cross_entropy_loss[:half].flatten()[loss_mask[:half].flatten()] hidden_states = hidden_states[:half] loss_unreduced = cross_entropy_loss flat_logits = logits[:half].view(-1, vocab_size) flat_labels = labels[:half].view(-1) flat_loss_mask = flat_labels >= 0 policy_chosen_logps = chosen_logps policy_rejected_logps = rejected_logps pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = self.reference_chosen_logps - self.reference_rejected_logps del self.reference_chosen_logps del self.reference_rejected_logps dpo_logits = pi_logratios - ref_logratios loss = -F.logsigmoid(rl_config['dpo_beta'] * dpo_logits).mean() else: raise NotImplementedError(self.loss_type) with torch.no_grad(): log_vocab_size = math.log(logits.size(-1)) entropy = entropy_fn(flat_logits)[flat_loss_mask] # Compute normalised entropy so we can compare between models with different vocab sizes normalised_entropy = entropy / log_vocab_size # Compute the (negative) log-likelihood using the original *UNADJUSTED* Cross-Entropy loss. log_likelihood = cross_entropy_loss.mean() # Compute McFadden's Pseudo-R² metric using log(vocab_size) as the null log-likelihood. 
mcfaddens_pseudo_r2 = 1 - (log_likelihood / log_vocab_size) accuracies = top_k_accuracy(flat_logits, flat_labels, k_list=[1, 5, 20]) # Compute the norms of the (pre-logit-scaled) hidden states hidden_state_norms = torch.norm(hidden_states.float(), dim=-1) hidden_state_norms = hidden_state_norms.view(-1)[flat_loss_mask] if loss is None: # Normal language modeling loss types (e.g. not DPO) loss = loss_unreduced.mean() loss_unreduced = loss_unreduced.detach() return ( loss, loss_unreduced, hidden_state_norms, entropy, normalised_entropy, log_likelihood, mcfaddens_pseudo_r2, *accuracies, ) def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float: if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device stacked_gate_logits = torch.stack([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(stacked_gate_logits, dim=-1) # [num_layers, num_tokens, num_experts] _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) # [num_layers, num_tokens, top_k] expert_mask = torch.nn.functional.one_hot( selected_experts, num_experts ) # [num_layers, num_tokens, top_k, num_experts] # For a given token, determine if it was routed to a given expert. Think of this as a collection of top_k-hot vectors. 
expert_mask = torch.max(expert_mask, dim=-2).values.float() # [num_layers, num_tokens, num_experts] tokens_per_layer_and_expert = torch.mean(expert_mask, dim=-2) # [num_layers, num_experts] router_prob_per_layer_and_expert = torch.mean(routing_weights, dim=-2) # [num_layers, num_experts] return torch.mean(tokens_per_layer_and_expert * router_prob_per_layer_and_expert) * num_experts**2 class MixtralOutputLayer(OutputLayer): def __init__( self, pipeline_model, loader_util, lm_head, load_balancing_loss_coef, num_experts, num_experts_per_tok, **kwargs, ): super().__init__(pipeline_model, loader_util, lm_head, **kwargs) self.load_balancing_loss_coef = load_balancing_loss_coef self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok def forward(self, inputs): hidden_states, labels, *router_logits = inputs router_logits = tuple(router_logits) outputs = super().forward((hidden_states, labels)) if self.pipeline_model[0].sampling_mode: return outputs if self.load_balancing_loss_coef is not None: aux_loss = transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func( router_logits, self.num_experts, self.num_experts_per_tok ) alternate_aux_loss = load_balancing_loss_func(router_logits, self.num_experts, self.num_experts_per_tok) loss = outputs[0] loss += self.load_balancing_loss_coef * aux_loss outputs = (loss, *outputs[1:], aux_loss, alternate_aux_loss) return outputs class InputLayer(nn.Module): def __init__(self, model): super().__init__() self._model = [model] self.embed_tokens = model.model.embed_tokens self.rotary_emb = model.model.rotary_emb self.embedding_on_cpu = not self.model.train_config['full_fine_tune'] self.model.loader_util.load_state_dict_into_module(self) @property def model(self): return self._model[0] def forward(self, inputs): past_key_values = None cache_position = None use_cache = self.model.sampling_mode input_ids, attention_mask, labels = inputs[:3] device = input_ids.device if self.embedding_on_cpu: 
self.embed_tokens.to('cpu') input_ids = input_ids.to('cpu') inputs_embeds = self.embed_tokens(input_ids).to(device) if use_cache: past_key_values = self.model.cache past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) position_ids = cache_position.unsqueeze(0) attention_mask = self.model.model._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, None ) if attention_mask is None: # attention_mask can end up being None, which means use full causal attention. But with pipeline parallelism, # we can only pass tensors between layers. So make it an empty tensor, which will later be detected by the layers # and converted back to None. Note: this only works now because dynamic_shape=True in the pipeline engine. attention_mask = torch.tensor([], device=device) # Work around a very strange Deepspeed bug. The combination of PipelineModule dynamic_shape=True, attention_mask being # an integer dtype, and pipeline_stages>2 causes training (but not eval) to hang. So cast to float, and cast back to int # in the layer. if torch.is_tensor(attention_mask) and not torch.is_floating_point(attention_mask): attention_mask = attention_mask.to(inputs_embeds.dtype) hidden_states = inputs_embeds if self.model.model.config.model_type == 'gemma2': normalizer = torch.tensor(self.model.model.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer cos, sin = self.rotary_emb(hidden_states, position_ids) output = hidden_states, attention_mask, cos, sin, cache_position, labels # Deepspeed requirement. Float tensors must require grad. 
for tensor in output: if torch.is_floating_point(tensor): tensor.requires_grad_(True) return output class LlamaRMSNormPipe(nn.Module): def __init__(self, loader_util, orig): super().__init__() self.orig = orig loader_util.load_state_dict_into_module(self) def forward(self, inputs): hidden_states, _, _, _, _, labels, *router_logits = inputs return self.orig(hidden_states), labels, *router_logits class LlamaDecoderLayerPipe(nn.Module): def __init__(self, pipeline_model, loader_util, orig): super().__init__() self.pipeline_model = [pipeline_model] self.orig = orig self.mlp_offloaded_to_cpu = False self.attn_implementation = pipeline_model.config._attn_implementation loader_util.load_state_dict_into_module(self) # A note on MLP offloading: # We take advantage of how activation checkpointing works with reentrant checkpointing functions. # During the forward pass, if gradients are disabled (eval or first forward pass of activation checkpointing) # we offload the weights back to CPU at the end of the function. If gradients are enabled (second forward pass # of activation checkpointing) we leave the weights on GPU, and use a backward hook to offload to CPU after the # backward pass of this function is completed. This way the weights stay on the GPU for the backward pass. def forward(self, inputs): def move_mlp_to_cpu_hook(grad): self.move_mlp_to_cpu() return None hidden_states, attention_mask, cos, sin, cache_position, labels = inputs if self.mlp_offloaded_to_cpu: if hidden_states.requires_grad: hidden_states.register_hook(move_mlp_to_cpu_hook) self.move_mlp_to_device(hidden_states.device) kwargs = {} if attention_mask.numel() == 0: # We can't pass None between pipeline layers, so this signals that attention_mask should be None. kwargs['attention_mask'] = None else: # We have to pass attention mask between layers as float dtype, because in certain cases training hangs otherwise. So # now cast it back to int64 if we are using flash_attention_2. 
kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask kwargs['position_embeddings'] = (cos, sin) if self.pipeline_model[0].sampling_mode: kwargs['use_cache'] = True kwargs['past_key_value'] = self.pipeline_model[0].cache kwargs['cache_position'] = cache_position result = (self.orig(hidden_states, **kwargs)[0], attention_mask, cos, sin, cache_position, labels) if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled(): self.move_mlp_to_cpu() return result def move_mlp_to_cpu(self): # If it's already been moved to CPU once, just set the data to avoid a transfer. if self.mlp_offloaded_to_cpu: set_data(self.orig.mlp.up_proj, self.cpu_up_proj) set_data(self.orig.mlp.down_proj, self.cpu_down_proj) set_data(self.orig.mlp.gate_proj, self.cpu_gate_proj) return move_data_to_device(self.orig.mlp.up_proj, 'cpu') move_data_to_device(self.orig.mlp.down_proj, 'cpu') move_data_to_device(self.orig.mlp.gate_proj, 'cpu') self.mlp_offloaded_to_cpu = True def move_mlp_to_device(self, device): self.cpu_up_proj = move_data_to_device(self.orig.mlp.up_proj, device) self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device) self.cpu_gate_proj = move_data_to_device(self.orig.mlp.gate_proj, device) class Phi3DecoderLayerPipe(LlamaDecoderLayerPipe): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def move_mlp_to_cpu(self): if self.mlp_offloaded_to_cpu: set_data(self.orig.mlp.gate_up_proj, self.cpu_gate_up_proj) set_data(self.orig.mlp.down_proj, self.cpu_down_proj) return move_data_to_device(self.orig.mlp.gate_up_proj, 'cpu') move_data_to_device(self.orig.mlp.down_proj, 'cpu') self.mlp_offloaded_to_cpu = True def move_mlp_to_device(self, device): self.cpu_gate_up_proj = move_data_to_device(self.orig.mlp.gate_up_proj, device) self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device) class MixtralDecoderLayerPipe(LlamaDecoderLayerPipe): def __init__(self, *args, 
**kwargs): super().__init__(*args, **kwargs) self.num_experts_to_offload = self.pipeline_model[0].num_experts_to_offload def forward(self, inputs): def move_mlp_to_cpu_hook(grad): self.move_mlp_to_cpu() return None hidden_states, attention_mask, cos, sin, cache_position, labels, *input_router_logits = inputs if self.mlp_offloaded_to_cpu: if hidden_states.requires_grad: hidden_states.register_hook(move_mlp_to_cpu_hook) self.move_mlp_to_device(hidden_states.device) kwargs = {} if attention_mask.numel() == 0: # We can't pass None between pipeline layers, so this signals that attention_mask should be None. kwargs['attention_mask'] = None else: kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask kwargs['position_embeddings'] = (cos, sin) if self.pipeline_model[0].sampling_mode: kwargs['use_cache'] = True kwargs['past_key_value'] = self.pipeline_model[0].cache hidden_states, router_logits = self.orig(hidden_states, output_router_logits=True, **kwargs) # TODO: fix unsloth gradient checkpointing when we return router logits # router_logits = router_logits.to(torch.float32) # router_logits = input_router_logits + (router_logits,) # result = (hidden_states, attention_mask, cos, sin, labels, *router_logits) result = (hidden_states, attention_mask, cos, sin, cache_position, labels) if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled(): self.move_mlp_to_cpu() return result def move_mlp_to_cpu(self): if self.mlp_offloaded_to_cpu: set_experts_data(self.orig.block_sparse_moe.experts, self.orig_data) return move_experts_to_device(self.orig.block_sparse_moe.experts, 'cpu', self.num_experts_to_offload) self.mlp_offloaded_to_cpu = True def move_mlp_to_device(self, device): self.orig_data = move_experts_to_device( self.orig.block_sparse_moe.experts, device, self.num_experts_to_offload ) class Gemma3InputLayer(nn.Module): def __init__(self, model): super().__init__() self._model = [model] 
self.embed_tokens = model.model.embed_tokens self.rotary_emb = model.model.rotary_emb self.rotary_emb_local = model.model.rotary_emb_local self.embedding_on_cpu = not self.model.train_config['full_fine_tune'] self.model.loader_util.load_state_dict_into_module(self) @property def model(self): return self._model[0] def forward(self, inputs): past_key_values = None cache_position = None use_cache = self.model.sampling_mode input_ids, attention_mask, labels = inputs[:3] device = input_ids.device if self.embedding_on_cpu: self.embed_tokens.to('cpu') input_ids = input_ids.to('cpu') inputs_embeds = self.embed_tokens(input_ids).to(device) if use_cache: past_key_values = self.model.cache past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) position_ids = cache_position.unsqueeze(0) attention_mask = self.model.model._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, None ) if attention_mask is None: # attention_mask can end up being None, which means use full causal attention. But with pipeline parallelism, # we can only pass tensors between layers. So make it an empty tensor, which will later be detected by the layers # and converted back to None. Note: this only works now because dynamic_shape=True in the pipeline engine. attention_mask = torch.tensor([], device=device) # Work around a very strange Deepspeed bug. The combination of PipelineModule dynamic_shape=True, attention_mask being # an integer dtype, and pipeline_stages>2 causes training (but not eval) to hang. So cast to float, and cast back to int # in the layer. 
if torch.is_tensor(attention_mask) and not torch.is_floating_point(attention_mask): attention_mask = attention_mask.to(inputs_embeds.dtype) hidden_states = inputs_embeds if self.model.model.config.model_type == 'gemma2': normalizer = torch.tensor(self.model.model.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer position_embeddings_global_cos, position_embeddings_global_sin = self.rotary_emb(hidden_states, position_ids) position_embeddings_local_cos, position_embeddings_local_sin = self.rotary_emb_local(hidden_states, position_ids) output = hidden_states, attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels # Deepspeed requirement. Float tensors must require grad. for tensor in output: if torch.is_floating_point(tensor): tensor.requires_grad_(True) return output class Gemma3DecoderLayerPipe(nn.Module): def __init__(self, pipeline_model, loader_util, orig): super().__init__() self.pipeline_model = [pipeline_model] self.orig = orig self.mlp_offloaded_to_cpu = False self.attn_implementation = pipeline_model.config._attn_implementation loader_util.load_state_dict_into_module(self) def forward(self, inputs): def move_mlp_to_cpu_hook(grad): self.move_mlp_to_cpu() return None hidden_states, attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels = inputs if self.mlp_offloaded_to_cpu: if hidden_states.requires_grad: hidden_states.register_hook(move_mlp_to_cpu_hook) self.move_mlp_to_device(hidden_states.device) kwargs = {} if attention_mask.numel() == 0: # We can't pass None between pipeline layers, so this signals that attention_mask should be None. 
kwargs['attention_mask'] = None else: kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask kwargs['position_embeddings_global'] = (position_embeddings_global_cos, position_embeddings_global_sin) kwargs['position_embeddings_local'] = (position_embeddings_local_cos, position_embeddings_local_sin) kwargs['cache_position'] = cache_position if self.pipeline_model[0].sampling_mode: kwargs['use_cache'] = True kwargs['past_key_value'] = self.pipeline_model[0].cache result = ( self.orig(hidden_states, **kwargs)[0], attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels ) if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled(): self.move_mlp_to_cpu() return result def move_mlp_to_cpu(self): # If it's already been moved to CPU once, just set the data to avoid a transfer. if self.mlp_offloaded_to_cpu: set_data(self.orig.mlp.up_proj, self.cpu_up_proj) set_data(self.orig.mlp.down_proj, self.cpu_down_proj) set_data(self.orig.mlp.gate_proj, self.cpu_gate_proj) return move_data_to_device(self.orig.mlp.up_proj, 'cpu') move_data_to_device(self.orig.mlp.down_proj, 'cpu') move_data_to_device(self.orig.mlp.gate_proj, 'cpu') self.mlp_offloaded_to_cpu = True def move_mlp_to_device(self, device): self.cpu_up_proj = move_data_to_device(self.orig.mlp.up_proj, device) self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device) self.cpu_gate_proj = move_data_to_device(self.orig.mlp.gate_proj, device) class Gemma3RMSNormPipe(nn.Module): def __init__(self, loader_util, orig): super().__init__() self.orig = orig loader_util.load_state_dict_into_module(self) def forward(self, inputs): hidden_states, *_, labels = inputs return self.orig(hidden_states), labels ================================================ FILE: models/models.py ================================================ import accelerate 
import torch import transformers from models.layers import ( InputLayer, LayerSpec, LlamaDecoderLayerPipe, LlamaRMSNormPipe, MixtralDecoderLayerPipe, MixtralOutputLayer, OutputLayer, Phi3DecoderLayerPipe, Gemma3InputLayer, Gemma3DecoderLayerPipe, Gemma3RMSNormPipe, ) from models.pipeline_model import PipelineModel from utils.utils import DTYPE_MAP DEFAULT_ATTN_IMPLEMENTATION = 'flash_attention_2' # A little bit of inheritance and MRO trickery since LlamaForCausalLM.__init__ only takes a # positional argument. We inherit PipelineModel first, but call LlamaForCausalLM init first, # and make sure PipelineModel doesn't have a super().__init__() call. class LlamaForCausalLMPipe(PipelineModel, transformers.LlamaForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.LlamaConfig.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.LlamaForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): embedding_relative_size = 4 embedding_on_cpu = not self.train_config['full_fine_tune'] result = [LayerSpec(InputLayer, self, _estimated_size=0 if embedding_on_cpu else embedding_relative_size)] for block in self.model.layers: result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None, _estimated_size=embedding_relative_size, ) ) return result class Qwen2ForCausalLMPipe(PipelineModel, 
transformers.Qwen2ForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.Qwen2Config.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.Qwen2ForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): result = [LayerSpec(InputLayer, self)] for block in self.model.layers: result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None, ) ) return result class CohereForCausalLMPipe(PipelineModel, transformers.CohereForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.CohereConfig.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.CohereForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): # the embedding table for this model is huge; load balance it better with some heuristics embedding_relative_size = 4 embedding_on_cpu = not self.train_config['full_fine_tune'] result = [LayerSpec(InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)] for block in self.model.layers: 
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, logit_scale=self.logit_scale, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None, _estimated_size=embedding_relative_size, ) ) return result class Phi3ForCausalLMPipe(PipelineModel, transformers.Phi3ForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.Phi3Config.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.Phi3ForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): result = [LayerSpec(InputLayer, self)] for block in self.model.layers: result.append(LayerSpec(Phi3DecoderLayerPipe, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, ) ) return result class Gemma2ForCausalLMPipe(PipelineModel, transformers.Gemma2ForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.Gemma2Config.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.Gemma2ForCausalLM.__init__(self, model_config) 
PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): # the embedding table for this model is huge; load balance it better with some heuristics # this value optimized for LoRA, pipeline_stages=2 embedding_relative_size = 8 embedding_on_cpu = not self.train_config['full_fine_tune'] result = [LayerSpec(InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)] for block in self.model.layers: result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None, logit_softcapping=self.config.final_logit_softcapping, _estimated_size=embedding_relative_size, ) ) return result class MistralForCausalLMPipe(PipelineModel, transformers.MistralForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.MistralConfig.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.MistralForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): result = [LayerSpec(InputLayer, self)] for block in self.model.layers: result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, loss_type=self.loss_type, 
focal_loss_gamma=self.focal_loss_gamma, ) ) return result class MixtralForCausalLMPipe(PipelineModel, transformers.MixtralForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.MixtralConfig.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.MixtralForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) self.load_balancing_loss_coef = config.get('load_balancing_loss_coef', None) self.num_experts_to_offload = self.num_experts if 'offload_mlp_to_cpu' in config and isinstance(config['offload_mlp_to_cpu'], int): self.num_experts_to_offload = config['offload_mlp_to_cpu'] def to_layer_specs(self): result = [LayerSpec(InputLayer, self)] for block in self.model.layers: result.append(LayerSpec(MixtralDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( MixtralOutputLayer, self, self.loader_util, self.lm_head, load_balancing_loss_coef=self.load_balancing_loss_coef, num_experts=self.num_experts, num_experts_per_tok=self.num_experts_per_tok, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, ) ) return result class Gemma3ForCausalLMPipe(PipelineModel, transformers.Gemma3ForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.Gemma3TextConfig.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.Gemma3ForCausalLM.__init__(self, model_config) 
PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): # the embedding table for this model is huge; load balance it better with some heuristics # this value optimized for LoRA, pipeline_stages=2 embedding_relative_size = 8 embedding_on_cpu = not self.train_config['full_fine_tune'] result = [LayerSpec(Gemma3InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)] for block in self.model.layers: result.append(LayerSpec(Gemma3DecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(Gemma3RMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None, logit_softcapping=self.config.final_logit_softcapping, _estimated_size=embedding_relative_size, ) ) return result class Cohere2ForCausalLMPipe(PipelineModel, transformers.Cohere2ForCausalLM): def __init__(self, config, quantization_config): model_config = transformers.Cohere2Config.from_pretrained(config['model']) model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION) torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')]) with accelerate.init_empty_weights(): transformers.Cohere2ForCausalLM.__init__(self, model_config) PipelineModel.__init__(self, config, quantization_config, model_config) torch.set_default_dtype(torch.float32) def to_layer_specs(self): # the embedding table for this model is huge; load balance it better with some heuristics embedding_relative_size = 8 embedding_on_cpu = not self.train_config['full_fine_tune'] result = [LayerSpec(InputLayer, self, _estimated_size=2 if embedding_on_cpu else embedding_relative_size)] for block in self.model.layers: 
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block)) result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0)) result.append( LayerSpec( OutputLayer, self, self.loader_util, self.lm_head, logit_scale=self.logit_scale, loss_type=self.loss_type, focal_loss_gamma=self.focal_loss_gamma, tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None, _estimated_size=embedding_relative_size, ) ) return result ================================================ FILE: models/pipeline_model.py ================================================ import os from collections import defaultdict from inspect import signature import re import accelerate import bitsandbytes as bnb import transformers from deepspeed.accelerator import get_accelerator from hqq.core import quantize as hqq_quantize from torch import nn from transformers.integrations import get_keys_to_not_convert from accelerate.utils import set_module_tensor_to_device import utils.hqq_utils as hqq_utils from utils.utils import is_main_process LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX = r'^language_model\.' 
class PipelineModel(nn.Module): def __init__(self, config, quantization_config, model_config): if config['full_fine_tune'] and model_config.tie_word_embeddings: raise NotImplementedError('FFT is not supported for models with tied embeddings') self.train_config = config self.model_config = model_config self.modules_to_not_quantize = get_keys_to_not_convert(self) self.loader_util = LoaderUtil(config['model'], quantization_config, self.modules_to_not_quantize) self.loss_type = config.get('loss_type', 'cross_entropy_loss').lower() if rl_config := config.get('rl', None): self.loss_type = rl_config['method'] self.focal_loss_gamma = config.get('focal_loss_gamma', 0) if self.focal_loss_gamma > 0 and is_main_process(): print(f"Optimizing using '{self.loss_type}' with gamma={self.focal_loss_gamma}") self.dpo_reference_mode = False self.sampling_mode = False for name, p in self.named_parameters(): p.original_name = name # need to override this method def to_layer_specs(self): raise NotImplementedError() def set_dpo_reference_mode(self, dpo_reference_mode): self.dpo_reference_mode = dpo_reference_mode def set_sampling_mode(self, sampling_mode): self.sampling_mode = sampling_mode # Reset cache when sampling mode is modified. This ensures it's initialized and also clears memory at the end. self.cache_dict = defaultdict(transformers.DynamicCache) # We could try to use static cache at some point. During early testing with relatively short sequence lengths, # it was the same sampling speed as DynamicCache. Note: will need to pass cache_position in transformer layer # if using StaticCache. 
# def make_static_cache(): # return transformers.StaticCache( # self.model_config, # max_batch_size=1, # max_cache_len=1024, # device='cuda', # dtype=self.dtype, # ) # self.cache_dict = defaultdict(make_static_cache) def set_cache(self, micro_batch_id): self.cache = self.cache_dict[micro_batch_id] def _partial_module_name_match(full_name, list_to_match): return any(key in full_name for key in list_to_match) def _replace_with_quantized_linear(parent_modules_map, name, full_name, quantization_config): if isinstance(quantization_config, transformers.BitsAndBytesConfig): _replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config) elif isinstance(quantization_config, hqq_utils.CustomHQQConfig): _replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config) else: raise NotImplementedError(f'Quantization config not implemented: {quantization_config}') def _replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config): """Replace a Linear layer with a BNB quantized version.""" if quantization_config.llm_int8_skip_modules is not None and _partial_module_name_match( full_name, quantization_config.llm_int8_skip_modules ): return module = parent_modules_map[name] with accelerate.init_empty_weights(): if isinstance(module, nn.Conv1d): in_features, out_features = module.weight.shape else: in_features = module.in_features out_features = module.out_features if quantization_config.quantization_method() == 'llm_int8': parent_modules_map[name] = bnb.nn.Linear8bitLt( in_features, out_features, module.bias is not None, has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, threshold=quantization_config.llm_int8_threshold, ) else: extra_kwargs = ( {'quant_storage': quantization_config.bnb_4bit_quant_storage} if 'quant_storage' in list(signature(bnb.nn.Linear4bit).parameters) else {} ) parent_modules_map[name] = bnb.nn.Linear4bit( in_features, out_features, module.bias is not None, 
quantization_config.bnb_4bit_compute_dtype, compress_statistics=quantization_config.bnb_4bit_use_double_quant, quant_type=quantization_config.bnb_4bit_quant_type, **extra_kwargs, ) # Store the module class in case we need to transpose the weight later parent_modules_map[name].source_cls = type(module) # Force requires grad to False to avoid unexpected errors parent_modules_map[name].requires_grad_(False) def _replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config): """Replace a Linear layer with a HQQ quantized version.""" if _partial_module_name_match(full_name, quantization_config.skip_modules): return module = parent_modules_map[name] quant_config_dict = quantization_config.get_dict(full_name) hqq_linear = hqq_quantize.HQQLinear( module, quant_config=quant_config_dict, compute_dtype=quantization_config.compute_dtype, device=module.weight.device, initialize=True, del_orig=True, ) # Quantization itself uses a decent amount of VRAM. Temporarily move each quantized parameter to the CPU as we # finish, so the quant process doesn't OOM. Deepspeed will move everything to the correct device later. hqq_linear.W_q.data = hqq_linear.W_q.data.to('cpu') # Store the module class in case we need to transpose the weight later hqq_linear.source_cls = type(module) # Force requires grad to False to avoid unexpected errors hqq_linear.requires_grad_(False) parent_modules_map[name] = hqq_linear # modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py def _recursively_replace_with_quantized_linear( model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, ): """ Returns the converted model and a boolean that indicates if the conversion has been successful or not. 
""" for name, module in model.named_children(): if current_key_name is None: current_key_name = [] current_key_name.append(name) if (isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d)) and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` current_key_name_str = '.'.join(current_key_name) if not any( (key + '.' in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert ): _replace_with_quantized_linear(model._modules, name, current_key_name_str, quantization_config) # copy over the original_name attribute we added earlier (needed for loading weights) for orig_name, orig_p in module.named_parameters(): if hasattr(orig_p, 'original_name'): for new_name, new_p in model._modules[name].named_parameters(): if new_name == orig_name: new_p.original_name = orig_p.original_name if len(list(module.children())) > 0: _recursively_replace_with_quantized_linear( module, modules_to_not_convert, current_key_name, quantization_config, ) # Remove the last key for recursion current_key_name.pop(-1) class LoaderUtil: def __init__(self, model_path, quantization_config, modules_to_not_quantize): self.model_path = model_path self.quantization_config = quantization_config self.modules_to_not_quantize = modules_to_not_quantize self.local_rank = int(os.environ.get('LOCAL_RANK', None)) assert self.local_rank is not None self.device = get_accelerator().device_name(self.local_rank) index_file = os.path.join(model_path, transformers.utils.SAFE_WEIGHTS_INDEX_NAME) if os.path.exists(index_file): checkpoint_files, checkpoint_metadata = transformers.utils.hub.get_checkpoint_shard_files( model_path, index_file, local_files_only=True ) self.checkpoint_metadata = checkpoint_metadata else: self.checkpoint_metadata = None self.loaded_state_dict = None def get_partial_state_dict(self, leaf_file): if self.loaded_state_dict is None or leaf_file != self.loaded_state_dict[0]: print(f'loading checkpoint file 
{leaf_file}') state_dict = transformers.modeling_utils.load_state_dict(os.path.join(self.model_path, leaf_file)) state_dict = {re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in state_dict.items()} self.loaded_state_dict = (leaf_file, state_dict) return self.loaded_state_dict[1] def maybe_quantize(self, module): if self.quantization_config is None: return modules_to_not_convert = self.modules_to_not_quantize if not isinstance(modules_to_not_convert, list): modules_to_not_convert = [modules_to_not_convert] _recursively_replace_with_quantized_linear( module, modules_to_not_convert=modules_to_not_convert, quantization_config=self.quantization_config ) # Make sure to set this or PEFT (and probably other things) will break in strange ways. # We only need this because we do the loading and quanting ourselves. self.is_loaded_in_4bit = True def load_state_dict_into_module(self, module): print(f'load params into module {type(module)}') if isinstance(self.quantization_config, transformers.BitsAndBytesConfig): # bnb needs to replace with quantized linear before weights are loaded self.maybe_quantize(module) param_renaming_map = {p.original_name: new_name for new_name, p in module.named_parameters()} expected_keys = [p.original_name for p in module.parameters()] # If we have any extra attributes on the parameter, loading with BNB 4bit params breaks, so delete them. 
for p in module.parameters(): del p.original_name if self.checkpoint_metadata is not None: weight_map = self.checkpoint_metadata['weight_map'] weight_map = {re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in weight_map.items()} needed_checkpoint_files = {weight_map[key.replace('orig.', '')] for key in expected_keys} else: needed_checkpoint_files = ['model.safetensors'] for checkpoint_file in needed_checkpoint_files: state_dict = self.get_partial_state_dict(checkpoint_file) renamed_state_dict = {param_renaming_map[k]: v for k, v in state_dict.items() if k in param_renaming_map} for name, param in module.named_parameters(): if name in renamed_state_dict: set_module_tensor_to_device(module, name, device='cpu', value=renamed_state_dict[name]) module.to(self.device) if not isinstance(self.quantization_config, transformers.BitsAndBytesConfig): self.maybe_quantize(module) ================================================ FILE: pyproject.toml ================================================ [tool.black] # Only used by `hf-doc-builder´. 
line-length = 119 target-version = ['py38'] [tool.ruff] target-version = "py38" line-length = 119 extend-exclude = ["axolotl"] [tool.ruff.format] quote-style = "single" [tool.ruff.lint] extend-select = [ "C", # Complexity "E", # PEP8 errors "F", # Pyflakes (unused imports, undefined names) "I", # Import sorting "UP", # Pyupgrade upgrades "W", # PEP8 warnings # "PT009", # Pytest assertions ] ignore = [ "C901", # Function too complex "E501", # Line length (handled by ruff-format) "E741", # Allow the letter 'l' "UP007", # X | Y style Unions # "PT009", # self.assertEqual etc ] [tool.ruff.lint.isort] lines-after-imports = 2 known-first-party = ["peft"] [tool.pytest] doctest_optionflags = [ "NORMALIZE_WHITESPACE", "ELLIPSIS", "NUMBER", ] [tool.pytest.ini_options] addopts = "--cov=src/peft --cov-report=term-missing --durations=10" markers = [ "single_gpu_tests: tests that run on a single GPU", "multi_gpu_tests: tests that run on multiple GPUs", "regression: whether to run regression suite test", "bitsandbytes: select bitsandbytes integration tests" ] ================================================ FILE: requirements.txt ================================================ torch torchvision torchaudio accelerate bitsandbytes datasets deepspeed packaging peft safetensors scipy sentencepiece tensorboard toml jsonlines transformers flash-attn pyyaml tqdm triton hqq torch-optimi ruff # minimum Axolotl dependencies for what we use addict optimum pynvml evaluate wandb numba colorama trl mlflow termcolor fastcore ================================================ FILE: tools/convert_dpo_dataset_to_chat_format.py ================================================ # Convert a DPO dataset with prompt, chosen, rejected fields into chat format.
# Usage: python convert_dpo_dataset_to_chat_format.py hf_username/some_dataset path/to/output/directory import os import sys from pathlib import Path import datasets dataset_path, converted_path = sys.argv[1:] dataset = datasets.load_dataset(dataset_path) def convert(x): prompt = x['prompt'] chosen = [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': x['chosen']}] rejected = [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': x['rejected']}] return {'chosen': chosen, 'rejected': rejected} num_proc = min(64, os.cpu_count()) dataset = dataset.map(convert, num_proc=num_proc) for name, split in dataset.items(): filepath = Path(converted_path) / f'{name}.json' split.to_json(filepath) ================================================ FILE: tools/convert_ds_checkpoint_to_lora.py ================================================ # Very hacky script to convert pipeline parallel Deepspeed checkpoints into a saved lora model. # I originally wrote this because I screwed up the lora model saving initially, and needed a # way to turn the training checkpoints into saved lora models to test them. 
import os.path import re from glob import glob import torch def convert_ds_checkpoint_to_lora(ds_checkpoint_dir, lora_output_dir): layer_checkpoint_files = glob(os.path.join(ds_checkpoint_dir, 'layer_*-model_states.pt')) combined_state_dict = {} for path in layer_checkpoint_files: match = re.fullmatch('layer_(.+)-model_states.pt', os.path.basename(path)) layer_idx = int(match.group(1)) - 2 state_dict = torch.load(path) for name, weight in state_dict.items(): converted_name = name.replace('orig', f'base_model.model.model.layers.{layer_idx}').replace('.default', '') combined_state_dict[converted_name] = weight os.makedirs(lora_output_dir, exist_ok=True) torch.save(combined_state_dict, os.path.join(lora_output_dir, 'adapter_model.bin')) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--input') parser.add_argument('--output') args = parser.parse_args() convert_ds_checkpoint_to_lora(args.input, args.output) ================================================ FILE: tools/merge_lora.py ================================================ # Usage: python merge_lora.py input_path lora_path output_path # Output path is created if it doesn't exist import argparse import os import re import shutil from pathlib import Path import safetensors import torch from tqdm import tqdm import peft parser = argparse.ArgumentParser() parser.add_argument('input_path', type=str, help='The path to the input directory.') parser.add_argument('lora_path', type=str, help='The path to the LoRA directory.') parser.add_argument('output_path', type=str, help='The path to the output directory.') parser.add_argument('--no-gpu', action='store_true', help='Use CPU for merging.') args = parser.parse_args() input_path, lora_path, output_path = Path(args.input_path), Path(args.lora_path), Path(args.output_path) os.makedirs(output_path, exist_ok=True) lora_config = peft.LoraConfig.from_json_file(lora_path / 'adapter_config.json') scale = lora_config['lora_alpha'] / 
lora_config['r'] device = 'cpu' if args.no_gpu else 'cuda' print('Loading LoRA model...') # Check if we have adapter_model.bin or adapter_model.safetensors if (lora_path / 'adapter_model.safetensors').exists(): lora_state = safetensors.torch.load_file(lora_path / 'adapter_model.safetensors') if not args.no_gpu: # Move mapped entries to cuda for key, value in tqdm(lora_state.items()): lora_state[key] = value.to('cuda') else: lora_state = torch.load(lora_path / 'adapter_model.bin', map_location=device) def find_lora_weights(key): lora_A = None lora_B = None for lora_key, lora_weight in lora_state.items(): if key.strip('.weight') in lora_key: if 'lora_A' in lora_key: lora_A = lora_weight elif 'lora_B' in lora_key: lora_B = lora_weight else: raise RuntimeError() assert not ((lora_A is None) ^ (lora_B is None)) return lora_A, lora_B shards = [] for shard in input_path.glob('model*.safetensors'): shards.append(shard) print('Copying unmergable files to output') for filepath in input_path.glob('*'): if filepath in shards: continue filepath = Path(filepath) if filepath.is_dir(): continue if filepath.suffix == '.gguf': # Skip unrelated stray quantizations continue if filepath.suffix == '.safetensors': # Consolidated, possibly continue print(f'copying {filepath.name} to output') shutil.copy(filepath, output_path) print('Merging and copying state_dict to output') found = 0 for shard in (pbar := tqdm(shards)): tensors = {} with safetensors.safe_open(shard, framework='pt', device=device) as f: metadata = f.metadata() for key in f.keys(): lora_key = re.sub(r'^language_model\.', '', key) tensor = f.get_tensor(key) lora_A, lora_B = find_lora_weights(lora_key) if lora_A is not None: found += 1 pbar.set_description(f'found lora weights for {key}: {lora_A.size()}, {lora_B.size()}') old_type = tensor.dtype tensor = tensor.to(torch.float32) tensor += scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32) tensor = tensor.to(old_type) tensors[key] = tensor 
safetensors.torch.save_file(tensors, output_path / shard.name, metadata=metadata) print(f"Applied LoRA to {found} tensors.") ================================================ FILE: tools/test_sampling.py ================================================ # deepspeed --num_gpus=1 --module tools.test_sampling --config ~/code/qlora-pipe-configs/config_8b_dpo.toml import argparse import json import os.path import bitsandbytes import deepspeed import toml import transformers import utils.engine as engine from train import load_pipeline_model_with_lora from utils.utils import DTYPE_MAP, is_main_process PROMPT_FORMAT = """<|start_header_id|>user<|end_header_id|> {}<|eot_id|><|start_header_id|>assistant<|end_header_id|> """ PROMPTS = [ 'Where is Popeye Village located?', 'What is the name of Sweden in Swedish?', ] parser = argparse.ArgumentParser() parser.add_argument('--config', help='Path to TOML configuration file.') parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher') parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() if __name__ == '__main__': with open(args.config) as f: config = toml.load(f) config['full_fine_tune'] = True if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None: # engine.initialize() will load deepspeed config from args ds_config = None else: # The necessary ds_config fields are taken from the TOML config file. 
ds_config = { 'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1), 'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1), 'gradient_clipping': config.get('gradient_clipping', 1.0), 'steps_per_print': config.get('steps_per_print', 1), } deepspeed.init_distributed() with open(os.path.join(config['model'], 'config.json')) as f: model_config = json.load(f) model_type = model_config.get('model_type', 'llama') tokenizer = transformers.AutoTokenizer.from_pretrained( config['model'], local_files_only=True, model_max_length=int(1e30), padding_side='left', ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time. bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda def bnb_cuda_hijack(self, device): if getattr(self, 'already_quantized', False): self.data = self.data.to(device) self.quant_state.to(device) return self self.already_quantized = True return bnb_cuda_old(self, device) bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type) kwargs = {} if sampling_settings := config.get('sampling', None): for k, v in sampling_settings.items(): kwargs['sampling_' + k] = v model_engine, _ = engine.initialize( args=args, model=pipeline_model, lora_model=lora_model, config=ds_config, tokenizer=tokenizer, **kwargs, ) weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))] model_engine.communication_data_type = weight_dtype prompts = [PROMPT_FORMAT.format(prompt) for prompt in PROMPTS] #prompts = [[PROMPT_FORMAT.format(prompt) for prompt in PROMPTS]] outputs = model_engine.sample_batch(prompts, max_new_tokens=500) if is_main_process(): for text in outputs: print(text) print('-' * 80) ================================================ FILE: train.py 
================================================
# Entry point for pipeline-parallel (Q)LoRA / full fine-tuning with DeepSpeed.
# Reads a TOML config (--config), builds the pipelined model, optimizer, LR schedule and
# dataloaders, then runs the train/eval/save loop until the dataloader reports no next epoch.

import argparse
import glob
import itertools
import json
import os
import shutil
import time
from contextlib import contextmanager
from datetime import datetime, timezone

import bitsandbytes
import deepspeed
import toml
import torch
import transformers
from deepspeed.runtime.pipe.module import LayerSpec
from hqq.core import quantize as hqq_quantize
from torch.utils.tensorboard import SummaryWriter

import utils.dataloader as dataloader
import utils.engine as engine
import utils.hqq_utils as hqq_utils
import models.models as models
import utils.unsloth_utils as unsloth_utils
from utils.dataset_utils import load_datasets
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
from utils.saver import Saver
from utils.utils import DTYPE_MAP, is_main_process

# Command-line arguments are parsed at import time, so any module importing this file
# (e.g. tools/test_sampling.py) gets the same parsing of sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='Path to TOML configuration file.')
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--debug_dataset', type=int, help='print out this many training examples and then quit')
parser.add_argument(
    '--resume_from_checkpoint',
    action='store_true',
    default=None,  # None (not False) so "flag absent" can fall back to the TOML config value
    help='resume training from the most recent checkpoint',
)
parser.add_argument('--no_quantiles', action='store_true', help='suppress output of quantile metrics')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()


def print_model_info(model):
    """Debug helper: on the main process, print the model and every parameter's dtype/device/requires_grad."""
    if not is_main_process():
        return
    print(model)
    for name, module in model.named_modules():
        print(f'{type(module)}: {name}')
        for pname, p in module.named_parameters(recurse=False):
            print(pname)
            print(p.dtype)
            print(p.device)
            print(p.requires_grad)
            print()


def set_config_defaults(config):
    """Fill in defaults (False) for the 'full_fine_tune' and 'load_in_4bit' config keys, in place."""
    config['full_fine_tune'] = config.get('full_fine_tune', False)
    config['load_in_4bit'] = config.get('load_in_4bit', False)


def get_most_recent_run_dir(output_dir):
    """Return the lexicographically last entry of output_dir.

    Run directories are named with a UTC '%Y%m%d_%H-%M-%S' timestamp (see the main block), so the
    last name is the newest run.
    """
    # NOTE(review): glob('*') also returns plain files, and any entry whose name sorts after a
    # timestamp would be picked -- keep output_dir free of other entries.
    return sorted(glob.glob(os.path.join(output_dir, '*')))[-1]


def write_metrics(tb_writer, prefix, metrics, step):
    """Log a sequence of metric tensors to TensorBoard under `prefix` and return the mean loss.

    Entries are positional and only those present are logged:
        0: loss (its mean is logged and returned), 1: losses (log-histogram + quantiles),
        2: hidden-state norms, 3: entropy, 4: normalised entropy, 5: log-likelihood,
        6: McFadden's pseudo-R^2, 7/8/9: top-1/top-5/top-20 accuracy,
        10: load-balancing loss, 11: alternate load-balancing loss.
    Quantile metrics are skipped when --no_quantiles is passed.
    """
    loss = metrics[0].mean().item()
    tb_writer.add_scalar(f'{prefix}/loss', loss, step)

    if len(metrics) > 1:
        losses = metrics[1].view(-1)
        # log() is only defined for positive values, so zero losses are left out of the histogram.
        positive_losses = losses > 0
        tb_writer.add_histogram(f'{prefix}/log_loss_hist', torch.log(losses[positive_losses]), step)
        if not args.no_quantiles:
            # sorted_losses_idx and quantiles_idx are reused below to bucket entropy by loss quantile.
            sorted_losses, sorted_losses_idx = torch.sort(losses)
            quantiles = torch.tensor(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.999], dtype=torch.float32
            ).to(losses.device)
            quantiles_idx = [int(len(losses) * quantile) for quantile in quantiles]
            loss_quantiles = [sorted_losses[i] for i in quantiles_idx]
            for quantile, value in zip(quantiles, loss_quantiles):
                tb_writer.add_scalar(f'{prefix}/loss_quantile_{quantile:.3f}', value, step)

    if len(metrics) > 2:
        hidden_norm_avg = metrics[2].mean().item()
        tb_writer.add_scalar(f'{prefix}/hidden_norm_avg', hidden_norm_avg, step)
        hidden_state_norms = metrics[2].view(-1)
        tb_writer.add_histogram(f'{prefix}/hidden_norm_hist', hidden_state_norms, step)

    if len(metrics) > 3:
        entropy = metrics[3].view(-1)
        tb_writer.add_scalar(f'{prefix}/entropy', entropy.mean().item(), step)
        if not args.no_quantiles:
            assert entropy.size() == losses.size(), (entropy.size(), losses.size())
            # Order entropies by their loss, then average each slice between consecutive loss
            # quantile indices (the last slice runs to the end, since zip_longest pads j with None).
            sorted_entropy = entropy[sorted_losses_idx]
            entropy_quantiles = []
            for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):
                entropy_quantiles.append(sorted_entropy[i:j].mean())
            for quantile, value in zip(quantiles, entropy_quantiles):
                tb_writer.add_scalar(f'{prefix}/entropy_quantile_{quantile:.3f}', value, step)

    if len(metrics) > 4:
        normalised_entropy = metrics[4].view(-1)
        tb_writer.add_scalar(f'{prefix}/normalised_entropy', normalised_entropy.mean().item(), step)
        if not args.no_quantiles:
            assert normalised_entropy.size() == losses.size()
            # Same loss-quantile bucketing as for entropy above.
            sorted_normalised_entropy = normalised_entropy[sorted_losses_idx]
            normalised_entropy_quantiles = []
            for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):
                normalised_entropy_quantiles.append(sorted_normalised_entropy[i:j].mean())
            for quantile, value in zip(quantiles, normalised_entropy_quantiles):
                tb_writer.add_scalar(f'{prefix}/normalised_entropy_quantile_{quantile:.3f}', value, step)

    if len(metrics) > 5:
        # NOTE(review): exp(-x) is logged as 'likelihood' and exp(x) as 'perplexity', which implies
        # metrics[5] holds a *negative* log-likelihood despite the tag name -- confirm with the producer.
        log_likelihood = metrics[5].mean()
        tb_writer.add_scalar(f'{prefix}/log_likelihood', log_likelihood.item(), step)
        likelihood = torch.exp(-log_likelihood).item()
        tb_writer.add_scalar(f'{prefix}/likelihood', likelihood, step)
        perplexity = torch.exp(log_likelihood).item()
        tb_writer.add_scalar(f'{prefix}/perplexity', perplexity, step)

    if len(metrics) > 6:
        mcfaddens_pseudo_r2 = metrics[6].mean()
        tb_writer.add_scalar(f'{prefix}/mcfaddens_pseudo_r2', mcfaddens_pseudo_r2.item(), step)

    if len(metrics) > 7:
        tb_writer.add_scalar(f'{prefix}/top1_accuracy', metrics[7].mean().item(), step)
        tb_writer.add_scalar(f'{prefix}/top5_accuracy', metrics[8].mean().item(), step)
        tb_writer.add_scalar(f'{prefix}/top20_accuracy', metrics[9].mean().item(), step)

    if len(metrics) > 10:
        tb_writer.add_scalar(f'{prefix}/load_balancing_loss', metrics[10].mean().item(), step)

    if len(metrics) > 11:
        tb_writer.add_scalar(f'{prefix}/alternate_load_balancing_loss', metrics[11].mean().item(), step)

    return loss


def evaluate_single(model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps):
    """Run one pass over a single eval dataloader and log its metrics under 'eval/<name>'.

    Temporarily sets model_engine.micro_batches to eval_gradient_accumulation_steps and restores it
    afterwards. Returns the mean loss on the main process, None on every other rank.
    """
    orig_micro_batches = model_engine.micro_batches
    model_engine.micro_batches = eval_gradient_accumulation_steps
    iterator = iter(eval_dataloader)
    all_metrics = None
    while True:
        metrics = model_engine.eval_batch(iterator)
        eval_dataloader.sync_epoch()
        if all_metrics is None:
            all_metrics = [[] for _ in range(len(metrics))]
        # The loader's epoch counter reaching 2 marks the end of the pass.
        # NOTE(review): metrics from the call that moved the loader to epoch 2 are discarded; whether
        # that batch held real data depends on PipelineDataLoader.sync_epoch semantics -- confirm.
        if eval_dataloader.epoch == 2:
            break
        for i, metric in enumerate(metrics):
            all_metrics[i].append(metric)

    eval_dataloader.reset()
    model_engine.micro_batches = orig_micro_batches
    eval_metrics = [torch.cat(metric_list) for metric_list in all_metrics]
    loss = None
    if is_main_process():
        loss = write_metrics(tb_writer, f'eval/{name}', eval_metrics, step)
    return loss


def evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps):
    """Evaluate every named eval dataloader and log the total wall time as 'eval/eval_time_sec'.

    Returns the unweighted mean of the per-dataset losses on the main process; other ranks get None
    because evaluate_single only returns a loss on the main process.
    """
    if is_main_process():
        print('Running eval')
    start = time.time()
    loss = []
    for name, eval_dataloader in eval_dataloaders.items():
        loss_or_none = evaluate_single(
            model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps
        )
        if loss_or_none is not None:
            loss.append(loss_or_none)
    duration = time.time() - start
    if is_main_process():
        tb_writer.add_scalar('eval/eval_time_sec', duration, step)
    return sum(loss) / len(loss) if len(loss) > 0 else None


def apply_max_norm_regularization(model, config):
    """Optionally cap the norm of each LoRA update and report the update norms.

    For every lora_A/lora_B pair, W = (lora_alpha / lora_rank) * B @ A. If 'scale_weight_norms' is
    in config and ||W|| exceeds it, A and B are each multiplied in place by sqrt(max_norm / ||W||),
    so that their product is scaled by max_norm / ||W||.

    Returns:
        (keys_scaled, avg_norm, max_norm, norms): the number of pairs rescaled, the mean and max of
        the (post-scaling) norms, and the norms as a float32 tensor (an empty list if no LoRA pairs).

    Raises:
        RuntimeError: if any norm is NaN.
    """
    # modified from https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
    A_keys = []
    B_keys = []
    norms = []
    keys_scaled = 0
    lora_scale = config['lora_alpha'] / config['lora_rank']

    # state_dict() tensors share storage with the parameters, so the in-place *= below rescales
    # the live weights.
    state_dict = model.state_dict()
    for key in state_dict.keys():
        if 'lora_A' in key:
            A_keys.append(key)
            B_keys.append(key.replace('lora_A', 'lora_B'))

    for i in range(len(A_keys)):
        A = state_dict[A_keys[i]]
        B = state_dict[B_keys[i]]
        W = B @ A
        W *= lora_scale

        if 'scale_weight_norms' in config:
            max_norm = config['scale_weight_norms']
            # Clamping from below at max_norm / 2 avoids dividing by tiny norms; any norm up to
            # max_norm still yields ratio == 1 (no change).
            norm = W.norm().clamp(min=max_norm / 2)
            desired = torch.clamp(norm, max=max_norm)
            ratio = desired.cpu() / norm.cpu()
            sqrt_ratio = ratio**0.5
            if ratio != 1:
                keys_scaled += 1
                state_dict[A_keys[i]] *= sqrt_ratio
                state_dict[B_keys[i]] *= sqrt_ratio
        else:
            ratio = 1.0
        scalednorm = W.norm() * ratio
        norms.append(scalednorm.item())

    if len(norms) > 0:
        norms = torch.tensor(norms, dtype=torch.float32)
        if torch.any(torch.isnan(norms)):
            raise RuntimeError('NaN detected in norms, probably some/all weights are NaN')
        avg_norm = sum(norms) / len(norms)
        max_norm = max(norms)
    else:
        avg_norm = 0
        max_norm = 0
    return keys_scaled, avg_norm, max_norm, norms


def parse_layers_to_transform(spec):
    """Parse a comma-separated list of inclusive 'start:stop' ranges into a list of layer indices.

    Example: '0:2,5:6' -> [0, 1, 2, 5, 6].
    """
    parts = spec.split(',')
    result = []
    for part in parts:
        start, stop = part.split(':')
        result.extend(range(int(start), int(stop) + 1))
    return result


@contextmanager
def one_at_a_time():
    """Run the with-body on one local rank at a time, in LOCAL_RANK order.

    Every local rank calls the barrier LOCAL_SIZE times, and the body runs on the iteration matching
    its own LOCAL_RANK (both read from the environment).
    """
    for i in range(int(os.environ['LOCAL_SIZE'])):
        if i == int(os.environ['LOCAL_RANK']):
            yield
        deepspeed.comm.barrier()


def load_pipeline_model_with_lora(config, model_type):
    """Build the (optionally quantized) model, wrap it in a CustomPipelineModule and attach LoRA.

    Args:
        config: the parsed TOML training config.
        model_type: the 'model_type' value from the model's config.json.

    Returns:
        (pipeline_model, lora_model, lora_config); the last two are None for full fine-tuning.

    Raises:
        NotImplementedError: for an unsupported quantization config or model_type.
    """
    full_fine_tune = config['full_fine_tune']

    if config.get('quantization', None):
        assert not full_fine_tune
        no_quant_modules = ['lm_head']
        if model_type == 'mixtral':
            # the expert routing weights are tiny and probably important, don't quantize
            no_quant_modules.append('gate')
        if bnb_quant_config := config['quantization'].get('bnb', None):
            if bnb_compute_dtype := bnb_quant_config.get('bnb_4bit_compute_dtype', None):
                bnb_quant_config['bnb_4bit_compute_dtype'] = DTYPE_MAP[bnb_compute_dtype]
            if 'bnb_4bit_quant_type' not in bnb_quant_config:
                # Always want to default to nf4 if not specified.
                bnb_quant_config['bnb_4bit_quant_type'] = 'nf4'
            if llm_int8_skip_modules := bnb_quant_config.get('llm_int8_skip_modules', None):
                no_quant_modules.extend(llm_int8_skip_modules)
                no_quant_modules = list(set(no_quant_modules))
            bnb_quant_config['llm_int8_skip_modules'] = no_quant_modules
            quantization_config = transformers.BitsAndBytesConfig(**bnb_quant_config)
        elif hqq_quant_config := config['quantization'].get('hqq', None):
            quantization_config = hqq_utils.CustomHQQConfig(**hqq_quant_config)
            # Use ATEN backend if possible, else PYTORCH. PYTORCH_COMPILE was only a tiny bit faster, and requires triton nightly.
            hqq_quantize.HQQLinear.set_backend(
                hqq_quantize.HQQBackend.ATEN if quantization_config.use_aten() else hqq_quantize.HQQBackend.PYTORCH
            )
        else:
            raise NotImplementedError('Invalid quantization config')
        if is_main_process():
            print(f'Quantization config: {quantization_config}')
    else:
        quantization_config = None

    if model_type == 'llama':
        model = models.LlamaForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'mixtral':
        model = models.MixtralForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'qwen2':
        model = models.Qwen2ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'cohere':
        model = models.CohereForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'phi3':
        model = models.Phi3ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'gemma2':
        model = models.Gemma2ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'mistral' or model_type == 'mistral3':
        model = models.MistralForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'gemma3':
        model = models.Gemma3ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'cohere2':
        model = models.Cohere2ForCausalLMPipe(config, quantization_config=quantization_config)
    else:
        raise NotImplementedError(f'model_type {model_type} is not implemented')

    # CAREFUL! The "primary" layers of the model have to have 'decoderlayer' in them for
    # activation checkpointing to automatically work correctly.
    layers = model.to_layer_specs()
    checkpointable_layers = set()
    for layer in layers:
        if isinstance(layer, LayerSpec) and 'decoderlayer' in layer.typename.__name__.lower():
            checkpointable_layers.add(layer.typename.__name__)
    checkpointable_layers = list(checkpointable_layers)

    partition_method = 'estimated_size'
    if config['activation_checkpointing']:
        # NOTE: must use a reentrant checkpointing function for MLP offloading to work.
        if config['activation_checkpointing'] == 'unsloth':
            checkpoint_func = unsloth_utils.unsloth_checkpoint
        elif config['activation_checkpointing'] == 'cpu':
            deepspeed.checkpointing.configure(None, checkpoint_in_cpu=True)
            checkpoint_func = deepspeed.checkpointing.checkpoint
        else:
            checkpoint_func = deepspeed.checkpointing.checkpoint
        pipeline_model = engine.CustomPipelineModule(
            layers=layers,
            num_stages=config['pipeline_stages'],
            activation_checkpoint_interval=1,
            checkpointable_layers=checkpointable_layers,
            activation_checkpoint_func=checkpoint_func,
            partition_method=partition_method,
            use_column_major_topology=config.get('use_column_major_topology', False),
            model=model,
            dynamic_shape=True,
        )
    else:
        pipeline_model = engine.CustomPipelineModule(
            layers=layers,
            num_stages=config['pipeline_stages'],
            partition_method=partition_method,
            use_column_major_topology=config.get('use_column_major_topology', False),
        )

    target_modules = config['target_modules'] if 'target_modules' in config else 'all-linear'
    if full_fine_tune:
        lora_model = None
        lora_config = None
        # Remember each parameter's name from the unwrapped model on the parameter object itself.
        for name, p in model.named_parameters():
            p.original_name = name
        # With an explicit target_modules list, freeze every parameter whose name contains none of them.
        if isinstance(target_modules, list):
            for name, p in model.named_parameters():
                if not any(target in name for target in config['target_modules']):
                    p.requires_grad = False
                    print(f'not training {name} because it is not present in target_modules')
    else:
        layers_to_transform = (
            parse_layers_to_transform(config['layers_to_transform']) if 'layers_to_transform' in config else None
        )
        lora_config = LoraConfig(
            r=config['lora_rank'],
            lora_alpha=config['lora_alpha'],
            target_modules=target_modules,
            modules_to_save=config['modules_to_save'] if 'modules_to_save' in config else [],
            lora_dropout=config['lora_dropout'] if 'lora_dropout' in config else 0,
            layers_to_transform=layers_to_transform,
            bias='none',
            task_type='CAUSAL_LM',
            use_dora=config.get('use_dora', False),
        )

        lora_model = get_peft_model(model, lora_config)
        # If the underlying weights are floats, the lora weights have already been
        # cast to the same dtype, so we need to change the dtype here.
        for p in lora_model.parameters():
            if p.requires_grad:
                p.data = p.data.to(DTYPE_MAP[config.get('lora_weight_dtype', 'float32')])

        lora_model.model.config.use_cache = False
        for name, p in lora_model.named_parameters():
            p.original_name = name

    return pipeline_model, lora_model, lora_config


if __name__ == '__main__':
    # TODO: if resuming from checkpoint, probably should read all config files from checkpoint dir
    # rather than assume they are unchanged on the command line
    with open(args.config) as f:
        config = toml.load(f)
    set_config_defaults(config)

    if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
        # engine.initialize() will load deepspeed config from args
        ds_config = None
    else:
        # The necessary ds_config fields are taken from the TOML config file.
        ds_config = {
            'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1),
            'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1),
            'gradient_clipping': config.get('gradient_clipping', 1.0),
            'steps_per_print': config.get('steps_per_print', 1),
        }

    # The command-line flag wins when given; otherwise use the TOML value, defaulting to False.
    resume_from_checkpoint = (
        args.resume_from_checkpoint
        if args.resume_from_checkpoint is not None
        else config['resume_from_checkpoint']
        if 'resume_from_checkpoint' in config
        else False
    )

    deepspeed.init_distributed()

    with open(os.path.join(config['model'], 'config.json')) as f:
        model_config = json.load(f)
        model_type = model_config.get('model_type', 'llama')

    # Pad on left to support training techniques that involve sampling from the model.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config['model'], local_files_only=True, model_max_length=int(1e30), padding_side='left'
    )
    # TODO: do we want to do this with cohere models? By default the EOS token is <|END_OF_TURN_TOKEN|>
    # if model_type == 'cohere':
    #     tokenizer.eos_token = ''
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_data, eval_data_map = load_datasets(config, tokenizer)

    # --debug_dataset N: print the first N training examples on the main process, then exit on every rank.
    if args.debug_dataset:
        if is_main_process():
            for i, item in enumerate(iter(train_data)):
                print('input_ids:')
                print(item['input_ids'])
                print('decoded input_ids:')
                print(tokenizer.decode(item['input_ids']))
                print('attention_mask:')
                print(item['attention_mask'])
                print('labels:')
                print(item['labels'])
                if 'rejected_input_ids' in item:
                    print('rejected_input_ids:')
                    print(item['rejected_input_ids'])
                    print('decoded rejected_input_ids:')
                    print(tokenizer.decode(item['rejected_input_ids']))
                    print('rejected_attention_mask:')
                    print(item['rejected_attention_mask'])
                    print('rejected_labels:')
                    print(item['rejected_labels'])
                print('-' * 80)
                if i >= args.debug_dataset - 1:
                    break
        quit()

    # for testing
    # train_data = train_data.select(list(range(100)))

    # if this is a new run, create a new dir for it
    if not resume_from_checkpoint and is_main_process():
        run_dir = os.path.join(config['output_dir'], datetime.now(timezone.utc).strftime('%Y%m%d_%H-%M-%S'))
        os.makedirs(run_dir, exist_ok=True)
        shutil.copy(args.config, run_dir)
        if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
            shutil.copy(args.deepspeed_config, run_dir)
    # wait for all processes then get the most recent dir (may have just been created)
    deepspeed.comm.barrier()
    run_dir = get_most_recent_run_dir(config['output_dir'])

    # Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time.
    bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda

    def bnb_cuda_hijack(self, device):
        # The first call quantizes through the original method and marks the parameter; later calls
        # only move the already-quantized data and quant_state.
        if getattr(self, 'already_quantized', False):
            self.data = self.data.to(device)
            self.quant_state.to(device)
            return self
        self.already_quantized = True
        return bnb_cuda_old(self, device)

    bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack

    pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type)

    parameters_to_train = [p for p in pipeline_model.parameters() if p.requires_grad]

    optim_config = config['optimizer']

    def get_optimizer(model_parameters):
        """Optimizer factory passed to engine.initialize(); the type comes from config['optimizer']['type'].

        Supported types: 'adamw' (DeepSpeed FusedAdam), 'adamw8bit' (bitsandbytes) and 'adamw_kahan'
        (optimi.AdamW with Kahan summation). With use_loraplus, a LoRA+ optimizer is built over
        pipeline_model instead.
        """
        lr = optim_config['lr']
        optim_type = optim_config['type'].lower()
        optimizer_kwargs = {
            'params': model_parameters,
            'lr': lr,
            'betas': (optim_config.get('beta1', 0.9), optim_config.get('beta2', 0.99)),
            'weight_decay': optim_config.get('weight_decay', 0.01),
            'eps': optim_config.get('eps', 1e-6),
        }
        if optim_type == 'adamw':
            optimizer_cls = deepspeed.ops.adam.FusedAdam
        elif optim_type == 'adamw8bit':
            optimizer_cls = bitsandbytes.optim.AdamW8bit
        elif optim_type == 'adamw_kahan':
            import optimi

            optimizer_cls = optimi.AdamW
            optimizer_kwargs['kahan_sum'] = optim_config.get('kahan_sum', True)
        else:
            raise NotImplementedError(optim_type)

        if optim_config.get('use_loraplus', False):
            loraplus_lr_ratio = optim_config.get('loraplus_lr_ratio', 16)
            # TODO: handle params being thrown out here; why is it included in the first place?
            # delete 'params' from optimizer_kwargs
            del optimizer_kwargs['params']
            return create_loraplus_optimizer(
                model=pipeline_model,
                optimizer_cls=optimizer_cls,
                loraplus_lr_ratio=loraplus_lr_ratio,
                **optimizer_kwargs,
            )
        return optimizer_cls(**optimizer_kwargs)

    # Forward every [sampling] config entry to the engine as a sampling_<name> keyword argument.
    kwargs = {}
    if sampling_settings := config.get('sampling', None):
        for k, v in sampling_settings.items():
            kwargs['sampling_' + k] = v

    rejected_sampling = config.get('rejected_sampling', False)
    rl_config=config.get('rl', None)
    model_engine, optimizer = engine.initialize(
        args=args,
        model=pipeline_model,
        model_parameters=parameters_to_train,
        optimizer=get_optimizer,
        lora_model=lora_model,
        config=ds_config,
        tokenizer=tokenizer,
        rl_config=rl_config,
        rejected_sampling=rejected_sampling,
        rejected_sampling_max_new_tokens=config.get('rejected_sampling_max_new_tokens', 1e9),
        **kwargs,
    )

    # TODO: I have recently realized that we are setting things to fp16/bf16, even though all the DS
    # config was not in fp16 / bf16 mode. DS being in fp16/bf16 changes things in many places, e.g.
    # it can give you a BF16_Optimizer wrapper that accumulates grads in fp32, the communication dtype
    # is different, etc. I need to really look through all the implications of this. This change is so
    # that we keep the normal optimizer, but the communication dtype is changed so that we don't
    # unnecessarily cast grads to fp32.
    weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))]
    model_engine.communication_data_type = weight_dtype

    # TODO: the main DeepSpeedEngine forces all parameters to the GPU, and also does things like
    # broadcast all parameters from data parallel rank 0 to all other ranks. Thus, MLP offloading
    # must come after engine.initialize(). If we want to avoid loading everything onto GPUs only
    # to offload the MLPs, we have to rewrite a lot of code to work around things.
    if config.get('offload_mlp_to_cpu', False):
        assert config['activation_checkpointing']  # MLP offloading only works with activation checkpointing
        for module in pipeline_model.modules():
            if hasattr(module, 'move_mlp_to_cpu'):
                module.move_mlp_to_cpu()
        torch.cuda.empty_cache()

    train_dataloader = dataloader.PipelineDataLoader(
        train_data,
        tokenizer,
        model_engine.train_micro_batch_size_per_gpu(),
        model_engine.gradient_accumulation_steps(),
        model_engine.grid.get_data_parallel_world_size(),
        model_engine.grid.get_data_parallel_rank(),
        group_by_length=False if 'group_by_length' not in config else config['group_by_length'],
        batch_size_tokens=None if 'batch_size_tokens' not in config else config['batch_size_tokens'],
        return_dict=rejected_sampling,
        rl=(rl_config is not None),
    )
    model_engine.set_dataloader(train_dataloader)
    # One optimizer step consumes gradient_accumulation_steps entries of the loader.
    steps_per_epoch = len(train_dataloader) // model_engine.gradient_accumulation_steps()
    model_engine.total_steps = steps_per_epoch * config['epochs']

    if is_main_process():
        # Warn if eval dataset is unusually large compared to the eval steps
        eval_data_length = sum([len(eval_data) for eval_data in eval_data_map.values()])
        train_data_length = len(train_data)
        evals_per_epoch = steps_per_epoch / config['eval_steps']
        relative_eval_time = evals_per_epoch * eval_data_length
        # train step very roughly 3 times slower due to backprop + usually activation checkpointing is enabled
        relative_train_time = train_data_length * 3
        # Expect <=15% of our time spent evaluating vs training
        fraction_evaling = relative_eval_time / (relative_eval_time + relative_train_time)
        print()
        print(
            f'eval_data_length: {eval_data_length}, eval_steps: {config["eval_steps"]}; evals per epoch: {evals_per_epoch}. '
            f'We will be spending approximately {fraction_evaling * 100:.2f}% of our time evaluating.'
        )
        if fraction_evaling > 0.15:
            print(
                'WARNING: eval dataset is unusually large compared to eval_steps. We will spend a lot of time evaluating. Lowering eval_size and/or bumping eval_steps is recommended.'
            )
        print()

    # handle Deepspeed optimizer wrapper (e.g. BF16_Optimizer)
    optimizer = getattr(optimizer, 'optimizer', optimizer)

    warmup_steps = config.get('warmup_steps', 0)
    # Fractional values less than 1 are converted into "fraction of epoch" worth of steps
    if 0 < warmup_steps < 1:
        warmup_steps = int(warmup_steps * steps_per_epoch)

    if 'lr_scheduler' not in config or config['lr_scheduler'] == 'constant' or config['lr_scheduler'] == 'none':
        lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    elif config['lr_scheduler'] == 'cosine':
        # The cosine phase covers only the steps remaining after warmup.
        total_steps = steps_per_epoch * config['epochs']
        total_steps -= warmup_steps
        lr_scheduler_kwargs = {
            'optimizer': optimizer,
            'T_max': total_steps,
        }
        if 'lr_min' in optim_config:
            lr_scheduler_kwargs['eta_min'] = optim_config['lr_min']
        # Normally, you would pass the lr_scheduler to deepspeed.initialize(). But we need the
        # global batch_size in order to make the lr_scheduler.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(**lr_scheduler_kwargs)
    else:
        raise NotImplementedError()

    load_optimizer_states = config.get('load_optimizer_states', True)
    # if resuming and not loading optimizer states, we can't use warmup or the LR never changes from the initial value (still don't know why)
    if warmup_steps > 0 and load_optimizer_states:
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1 / warmup_steps, total_iters=warmup_steps
        )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
        )
    model_engine.lr_scheduler = lr_scheduler

    step = 1
    if resume_from_checkpoint:
        load_path, client_state = model_engine.load_checkpoint(
            run_dir,
            load_module_strict=False,
            load_lr_scheduler_states='force_constant_lr' not in config,
            load_optimizer_states=load_optimizer_states,
        )
        deepspeed.comm.barrier()  # just so the print below doesn't get swamped
        assert load_path is not None
        train_dataloader.load_state_dict(client_state['custom_loader'])
        step = client_state['step'] + 1
        del client_state
        # if we skip loading the optimizer states, we need to step the LR scheduler so we start at the right value
        if not load_optimizer_states:
            model_engine.lr_scheduler.step()

        if is_main_process():
            print(f'Resuming training from checkpoint. Resuming at epoch: {train_dataloader.epoch}, step: {step}')

    # force_constant_lr replaces whatever schedule was built above and pins every param group's LR.
    if 'force_constant_lr' in config:
        model_engine.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
        for pg in optimizer.param_groups:
            pg['lr'] = config['force_constant_lr']

    # this is a separate option, because if it's too high we might drop a significant fraction of the eval dataset
    eval_gradient_accumulation_steps = (
        config['eval_gradient_accumulation_steps'] if 'eval_gradient_accumulation_steps' in config else 1
    )
    # Eval dataset doesn't need to repeat; we just use this to track "epoch" so we know when we're done iterating over it.
    eval_dataloaders = {
        name: dataloader.PipelineDataLoader(
            eval_data,
            tokenizer,
            model_engine.train_micro_batch_size_per_gpu(),
            eval_gradient_accumulation_steps,
            model_engine.grid.get_data_parallel_world_size(),
            model_engine.grid.get_data_parallel_rank(),
            shuffle=False,
            group_by_length=False if 'group_by_length' not in config else config['group_by_length'],
            batch_size_tokens=None if 'batch_size_tokens' not in config else config['batch_size_tokens'],
            return_dict=rejected_sampling,
            rl=(rl_config is not None),
        )
        for name, eval_data in eval_data_map.items()
    }

    # Only the main process writes TensorBoard logs.
    tb_writer = SummaryWriter(log_dir=run_dir) if is_main_process() else None

    epoch = train_dataloader.epoch
    saver = Saver(model_engine, pipeline_model, train_dataloader, lora_config, run_dir, args, config)
    epoch = train_dataloader.epoch

    if config.get('eval_before_first_step', False) and not resume_from_checkpoint:
        loss = evaluate(model_engine, eval_dataloaders, tb_writer, 0, eval_gradient_accumulation_steps)
        saver.append_eval_results(loss, save_best=False)

    # Main loop: one optimizer step per iteration. It ends when, after an epoch change,
    # saver.process_epoch() returns None as the new epoch.
    while True:
        metrics = model_engine.train_batch()
        train_dataloader.sync_epoch()
        if lora_config is not None:
            keys_scaled, avg_norm, max_norm, norms = apply_max_norm_regularization(pipeline_model, config)

        new_epoch = saver.process_epoch(epoch, step)
        finished_epoch = True if new_epoch != epoch else False

        if is_main_process() and step % config['logging_steps'] == 0:
            write_metrics(tb_writer, 'train', metrics, step)
            tb_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
            # TODO: gather the weight norms across all stages in the pipelined model, not just the first.
            if lora_config is not None and len(norms) > 0:
                tb_writer.add_scalar('train/weights_scaled', keys_scaled, step)
                tb_writer.add_scalar('train/weight_norm_avg', avg_norm, step)
                tb_writer.add_scalar('train/weight_norm_max', max_norm, step)
                tb_writer.add_histogram('train/weight_norm_hist', norms, step)
            tb_writer.add_scalar('train/epoch', step / steps_per_epoch, step)

        if step % config['eval_steps'] == 0:
            loss = evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps)
            saver.append_eval_results(loss)

        saver.process_step(step)

        if finished_epoch:
            epoch = new_epoch
            if epoch is None:
                break

        step += 1

    # The last executed step is step - 1; evaluate it unless the loop already did.
    if ((step - 1) % config['eval_steps'] != 0) and config.get('eval_after_last_step', False):
        loss = evaluate(model_engine, eval_dataloaders, tb_writer, step - 1, eval_gradient_accumulation_steps)
        saver.append_eval_results(loss)

    if is_main_process():
        print('TRAINING COMPLETE!')

================================================
FILE: utils/dataloader.py
================================================
import math
import os.path
import sys


sys.path.insert(0, os.path.abspath('axolotl/src'))

import accelerate
import torch
import transformers
from deepspeed import comm as dist
from torch.utils.data import DataLoader

from axolotl.utils.collators import DataCollatorForSeq2Seq
from utils.utils import is_main_process


# A100 wants padding to multiple of 64, other cards are efficient with smaller, so just do
# 64 for all cards.
PAD_TO_MULTIPLE = 64


def split_batch(example, pieces):
    """Split a batched feature dict along the batch dimension into ``pieces`` feature dicts.

    Every tensor in ``example`` is cut along dim 0 with ``torch.split`` into equal slices.

    Args:
        example: dict mapping feature name -> tensor; all tensors must have the same size along dim 0
            as ``example['input_ids']``.
        pieces: number of pieces to produce (the gradient accumulation steps at the call sites).

    Returns:
        A list of ``pieces`` dicts with the same keys as ``example``.

    Raises:
        AssertionError: if the batch size is not a positive multiple of ``pieces``, or if a tensor's
            batch size differs from that of ``input_ids``.
    """
    input_ids = example['input_ids']
    if is_main_process():
        print(f'before GAS splitting, input_ids shape: {input_ids.shape}, total tokens: {input_ids.numel()}')
    input_batch_size = input_ids.size(0)
    # With a non-multiple batch size torch.split would return more than `pieces` chunks (and
    # examples[i] would raise IndexError), or fail outright for split_size == 0. Fail early and clearly.
    assert pieces > 0 and input_batch_size >= pieces and input_batch_size % pieces == 0, (
        f'batch size {input_batch_size} cannot be split into {pieces} equal pieces'
    )
    split_size = input_batch_size // pieces
    examples = [{} for _ in range(pieces)]
    for key, tensor in example.items():
        assert tensor.size(0) == input_batch_size
        for i, tensor_slice in enumerate(torch.split(tensor, split_size)):
            examples[i][key] = tensor_slice
    return examples


def combine_piecewise(a, b, pieces):
    """Interleave two equal-length lists of examples chunk by chunk. Used for DPO.

    ``a`` and ``b`` are each cut into ``pieces`` contiguous chunks of equal size. The result is
    ``a_chunk0 + b_chunk0 + a_chunk1 + b_chunk1 + ...``, so each contiguous piece of the result has its
    first half from ``a`` and its second half from ``b``. The chunking must match how split_batch()
    later splits the collated batch into ``pieces`` micro-batches.

    Args:
        a: first list (e.g. accepted examples).
        b: second list, same length as ``a`` (e.g. rejected examples).
        pieces: number of chunks per list.

    Returns:
        A new list of length ``2 * len(a)``.

    Raises:
        AssertionError: if the lengths differ, or ``len(a)`` is not a positive multiple of ``pieces``.
    """
    assert len(a) == len(b)
    # A non-multiple length would produce more than `pieces` chunks, and the halves would silently stop
    # lining up with split_batch(); an empty or too-short list would give a zero-size chunk step.
    assert pieces > 0 and len(a) >= pieces and len(a) % pieces == 0, (
        f'{len(a)} examples cannot be split into {pieces} equal pieces'
    )
    split_size = len(a) // pieces
    result = []
    for start in range(0, len(a), split_size):
        result.extend(a[start : start + split_size])
        result.extend(b[start : start + split_size])
    return result


# Flattens a list of examples with batch dimension into a list of examples with no batch dimension.
def flatten_examples(examples): result = [] for example in examples: batch_size = example['input_ids'].size(0) new_examples = [{} for _ in range(batch_size)] for key, tensor in example.items(): assert tensor.size(0) == batch_size for i, tensor_slice in enumerate(tensor): new_examples[i][key] = tensor_slice result.extend(new_examples) return result def example_to_tuple(example): return (example['input_ids'], example['attention_mask'], example['labels']), None def shuffle_list(l, seed): g = torch.Generator() g.manual_seed(seed) shuffle_idx = torch.randperm(len(l), generator=g).tolist() new_l = [l[i] for i in shuffle_idx] return new_l def batch_size_tokens_after_padding(batch): return max(math.ceil(pair[1] / PAD_TO_MULTIPLE) * PAD_TO_MULTIPLE for pair in batch) * len(batch) # A distributed batch sampler that supports grouping by length class DistributedBatchSamper(torch.utils.data.Sampler): def __init__( self, dataset, batch_size, num_replicas, rank, batch_size_multiplier=1, shuffle=True, group_by_length=False, seed=0, batch_size_tokens=None, ): self.dataset = dataset self.batch_size = batch_size self.batch_size_tokens = batch_size_tokens self.batch_size_multiplier = batch_size_multiplier self.num_replicas = num_replicas self.rank = rank # every global batch must be evenly divisible by this amount self.chunk_size = self.num_replicas * self.batch_size_multiplier self.shuffle = shuffle self.group_by_length = group_by_length self.seed = seed # Make list of (index, size). Sort or shuffle as needed. indices = list(enumerate(self.dataset['length'])) if self.group_by_length: indices.sort(key=lambda t: t[1]) elif self.shuffle: indices = shuffle_list(indices, self.seed) # Group indices together into global batches. 
global_batches = [] current_batch = [] for i in range(0, len(indices), self.chunk_size): slice = indices[i : i + self.chunk_size] if len(slice) < self.chunk_size: # pad with random examples if slice is too small padding_size = self.chunk_size - len(slice) shuffled_indices = shuffle_list(indices, self.seed + 1) if padding_size < len(shuffled_indices): slice += shuffled_indices[:padding_size] else: slice += (shuffled_indices * math.ceil(padding_size / len(shuffled_indices)))[:padding_size] if self.should_emit_current_batch(current_batch, slice): global_batches.append(current_batch) current_batch = [] current_batch.extend(slice) # Emit anything remaining if len(current_batch) > 0: global_batches.append(current_batch) if self.shuffle: global_batches = shuffle_list(global_batches, self.seed + 2) # make sure the largest batch comes first to OOM sooner rather than later largest_global_batch = 0 max_tokens = 0 for global_batch_idx, batch in enumerate(global_batches): total_batch_tokens = batch_size_tokens_after_padding(batch) if total_batch_tokens > max_tokens: max_tokens = total_batch_tokens largest_global_batch = global_batch_idx global_batches[0], global_batches[largest_global_batch] = ( global_batches[largest_global_batch], global_batches[0], ) batches_for_this_rank = [ global_batch[self.rank : len(global_batch) : self.num_replicas] for global_batch in global_batches ] self.indices = [[i for i, _ in batch] for batch in batches_for_this_rank] def should_emit_current_batch(self, current_batch, slice): if not self.batch_size_tokens: batch_size_after_appending = len(current_batch) // self.chunk_size + 1 if batch_size_after_appending > self.batch_size: return True else: return False else: global_batch_size_tokens = self.batch_size_tokens * self.chunk_size current_batch_tokens_after_appending = batch_size_tokens_after_padding(current_batch + slice) if len(current_batch) > 0 and current_batch_tokens_after_appending > global_batch_size_tokens: return True return False def 
__iter__(self): return iter(self.indices) def __len__(self): return len(self.indices) class PipelineDataLoader: def __init__( self, dataset, tokenizer, batch_size, gradient_accumulation_steps, data_parallel_world_size, data_parallel_rank, shuffle=True, group_by_length=False, pad_to_multiple_of=PAD_TO_MULTIPLE, batch_size_tokens=None, return_dict=False, rl=False, ): assert data_parallel_rank < data_parallel_world_size self.dataset = dataset self.tokenizer = tokenizer self.batch_size = batch_size self.batch_size_tokens = batch_size_tokens self.gradient_accumulation_steps = gradient_accumulation_steps self.pad_to_multiple_of = pad_to_multiple_of self.return_dict = return_dict self.rl = rl self.data_sampler = DistributedBatchSamper( dataset=dataset, batch_size=self.batch_size, batch_size_tokens=self.batch_size_tokens, batch_size_multiplier=self.gradient_accumulation_steps, num_replicas=data_parallel_world_size, rank=data_parallel_rank, shuffle=shuffle, group_by_length=group_by_length, ) data_collator = DataCollatorForSeq2Seq(self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) def collate_fn(examples, gradient_accumulation_steps=self.gradient_accumulation_steps): has_batch_dimension = examples[0]['input_ids'].ndim == 2 if has_batch_dimension: examples = flatten_examples(examples) rejected_examples = [] for example in examples: example.pop('length', None) example.pop('token_type_ids', None) rejected_example = {} for key in list(example.keys()): if 'rejected_' in key: x = example.pop(key) # Just drop the rejected_ entries if not doing RL. This allows normal SFT on just the # accepted completions of a RL dataset. 
if self.rl: rejected_example[key.strip('rejected_')] = x if rejected_example: rejected_examples.append(rejected_example) if rejected_examples: examples = combine_piecewise(examples, rejected_examples, gradient_accumulation_steps) return data_collator(examples) self.collate_fn = collate_fn self.epoch = 1 self.num_batches_pulled = 0 self.next_micro_batch = None self.recreate_dataloader = False self._create_dataloader() self.data = self._pull_batches_from_dataloader() def reset(self): self.epoch = 1 self.num_batches_pulled = 0 self.next_micro_batch = None self.data = self._pull_batches_from_dataloader() def __iter__(self): return self def __len__(self): return len(self.data_sampler) * self.gradient_accumulation_steps def __next__(self): if self.next_micro_batch is None: self.next_micro_batch = next(self.data) ret = self.next_micro_batch try: self.next_micro_batch = next(self.data) except StopIteration: if self.recreate_dataloader: self._create_dataloader() self.recreate_dataloader = False self.data = self._pull_batches_from_dataloader() self.num_batches_pulled = 0 self.next_micro_batch = next(self.data) self.epoch += 1 return ret def _pull_batches_from_dataloader(self): for batch in self.dataloader: self.num_batches_pulled += 1 for micro_batch in split_batch(batch, self.gradient_accumulation_steps): if self.return_dict: yield micro_batch else: # input to pipeline is (input_ids, attention_mask, labels) # this needs to return (features, labels) # it is OK if labels is None (the model just returns the loss anyway) yield example_to_tuple(micro_batch) def _create_dataloader(self): self.dataloader = DataLoader( self.dataset, pin_memory=True, batch_sampler=self.data_sampler, collate_fn=self.collate_fn, ) def state_dict(self): return { 'epoch': self.epoch, 'num_batches_pulled': self.num_batches_pulled, } def load_state_dict(self, state_dict): self.epoch = state_dict['epoch'] # -1 because by preloading the next micro_batch, it's always going to have one more batch # pulled 
than the actual number of batches iterated by the caller. self.num_batches_pulled = state_dict['num_batches_pulled'] - 1 self._create_dataloader() self.dataloader = accelerate.skip_first_batches(self.dataloader, self.num_batches_pulled) self.data = self._pull_batches_from_dataloader() # Recreate the dataloader after the first pass so that it won't skip # batches again (we only want it to skip batches the first time). self.recreate_dataloader = True # Only the first and last stages in the pipeline pull from the dataloader. Parts of the code need # to know the epoch, so we synchronize the epoch so the processes that don't use the dataloader # know the current epoch. def sync_epoch(self): process_group = dist.get_world_group() result = [None] * dist.get_world_size(process_group) torch.distributed.all_gather_object(result, self.epoch, group=process_group) max_epoch = -1 for epoch in result: max_epoch = max(epoch, max_epoch) self.epoch = max_epoch # for testing if __name__ == '__main__': tokenizer = transformers.AutoTokenizer.from_pretrained(sys.argv[1], local_files_only=True) tokenizer.pad_token_id = 1000 from datasets import Dataset data = [] for i in range(1, 41): input_ids = torch.tensor([i] * i) data.append( { 'input_ids': input_ids, 'attention_mask': torch.ones_like(input_ids), 'labels': input_ids, 'length': len(input_ids), } ) dataset = Dataset.from_list(data) # dataloader = PipelineDataLoader(dataset, tokenizer, batch_size=2, gradient_accumulation_steps=2, data_parallel_world_size=1, data_parallel_rank=0, group_by_length=True, pad_to_multiple_of=None) # for batch in dataloader: # if dataloader.epoch > 1: # break # print(batch) # print() batch_size = 2 gradient_accumulation_steps = 2 data_parallel_world_size = 2 data_parallel_rank = 0 dataloader = PipelineDataLoader( dataset, tokenizer, batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, data_parallel_world_size=data_parallel_world_size, data_parallel_rank=data_parallel_rank, 
shuffle=False, group_by_length=False, pad_to_multiple_of=None, ) print(next(dataloader)[0][0]) print(next(dataloader)[0][0]) print(next(dataloader)[0][0]) print(next(dataloader)[0][0]) state_dict = dataloader.state_dict() dataloader = PipelineDataLoader( dataset, tokenizer, batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, data_parallel_world_size=data_parallel_world_size, data_parallel_rank=data_parallel_rank, shuffle=False, group_by_length=False, pad_to_multiple_of=None, ) dataloader.load_state_dict(state_dict) print() print('-' * 80) print() print(next(dataloader)[0][0]) print(next(dataloader)[0][0]) print(next(dataloader)[0][0]) print(next(dataloader)[0][0]) ================================================ FILE: utils/dataset_utils.py ================================================ import os import os.path import sys sys.path.insert(0, os.path.abspath('axolotl/src')) import datasets import torch import yaml from tqdm import tqdm from axolotl.utils.data import prepare_dataset from axolotl.utils.dict import DictDefault from utils.utils import is_main_process, zero_first NUM_PROC = min(64, os.cpu_count()) def yield_sequences_from_token_batch(tokenizer, token_batch, sequence_len): # Initialize sequence_tokens with BOS token if it exists sequence_tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else [] for tokens in tqdm(token_batch): tokens = tokens.tolist() assert len(tokens) > 0, 'Empty tokens list' if tokens[-1] != tokenizer.eos_token_id: tokens.append(tokenizer.eos_token_id) idx = 0 # Skip the auto-generated BOS token if present if tokenizer.bos_token_id is not None and tokens[0] == tokenizer.bos_token_id: idx += 1 while idx < len(tokens): # Calculate how many tokens are needed to fill the sequence need = sequence_len - len(sequence_tokens) taken = tokens[idx : idx + need] idx += len(taken) sequence_tokens.extend(taken) if len(sequence_tokens) >= sequence_len: assert len(sequence_tokens) == sequence_len 
yield sequence_tokens # Reset sequence_tokens with BOS token if it exists sequence_tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else [] # yield anything remaining # TODO: disabled until I get training working with variable length sequences # if len(sequence_tokens) > 0: # yield sequence_tokens def slice_into_chunks(x, sequence_len, overlap=0): result = [] step = sequence_len - overlap for i in range(0, len(x), step): result.append(x[i : i + sequence_len]) return result def load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size, overlap=0, subsample_documents=None): if dataset_path.endswith('.txt'): dataset = datasets.load_dataset('text', data_files=dataset_path, sample_by='document')['train'] elif dataset_path.endswith('.json') or dataset_path.endswith('.jsonl'): dataset = datasets.load_dataset('json', data_files=dataset_path)['train'] else: raise NotImplementedError() dataset.set_format(type='torch') if subsample_documents: dataset = dataset.shuffle(seed=13).select(list(range(int(subsample_documents * len(dataset))))) dataset = dataset.map( lambda x: tokenizer(x['text']), batched=True, batch_size=10, remove_columns=dataset.column_names, desc='tokenizing', num_proc=NUM_PROC, ) # TODO: maybe do it this way instead # dataset = dataset.map(lambda x: {'tokens': slice_into_chunks(x['tokens'][0], sequence_len, overlap=overlap)}, batched=True, batch_size=1) dataset = dataset.map( lambda x: {'input_ids': list(yield_sequences_from_token_batch(tokenizer, x['input_ids'], sequence_len))}, batched=True, batch_size=None, remove_columns=dataset.column_names, desc='splitting', ) dataset = dataset.map( lambda x: {'attention_mask': torch.ones_like(x['input_ids']), 'labels': x['input_ids']}, desc='adding attention_mask and labels', ) if eval_size > 0: split_datasets = dataset.train_test_split(test_size=eval_size, shuffle=True, seed=42) train_data = split_datasets['train'] eval_data = split_datasets['test'] else: train_data = dataset eval_data = 
None return train_data, eval_data def load_axolotl_dataset(dataset_path, tokenizer, sequence_len, eval_size): with open(dataset_path) as f: cfg = yaml.safe_load(f.read()) if 'val_set_size' not in cfg: cfg['val_set_size'] = 0 if eval_size is None else eval_size cfg['sequence_len'] = sequence_len cfg['tokenizer_config'] = 'dummy' # these don't matter, but they have to be set cfg['batch_size'] = 1 cfg['num_epochs'] = 1 cfg['sequence_parallel_degree'] = 1 cfg = DictDefault(cfg) train_data, eval_data, *_ = prepare_dataset(cfg, tokenizer) train_data.set_format(type='torch') if eval_data is not None: eval_data.set_format(type='torch') return train_data, eval_data def load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eval_size): ds = datasets.load_from_disk(dataset_path) assert 'input_ids' in ds.column_names assert 'attention_mask' in ds.column_names assert 'labels' in ds.column_names ds = ds.filter(lambda example: len(example['input_ids']) <= sequence_len, desc='dropping long sequences', num_proc=NUM_PROC) if eval_size > 0: split_datasets = ds.train_test_split(test_size=eval_size, shuffle=True, seed=42) train_data = split_datasets['train'] eval_data = split_datasets['test'] else: train_data = ds eval_data = None return train_data, eval_data def load_single_dataset(dataset_config, tokenizer): dataset_path = dataset_config['dataset_path'] dataset_type = dataset_config['dataset_type'] sequence_len = dataset_config['sequence_len'] eval_size = dataset_config.get('eval_size', 0) subsample = dataset_config.get('subsample', None) num_repeats = dataset_config.get('num_repeats', None) if dataset_type in ['textfile', 'doclist']: with zero_first(is_main_process()): train_data, eval_data = load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size) elif dataset_type == 'axolotl': train_data, eval_data = load_axolotl_dataset(dataset_path, tokenizer, sequence_len, eval_size) elif dataset_type == 'pretokenized': with zero_first(is_main_process()): train_data, 
eval_data = load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eval_size) else: raise NotImplementedError() train_data = train_data.shuffle(seed=42) if eval_data is not None: eval_data = eval_data.shuffle(seed=42) if subsample is not None: assert 0 < subsample < 1 train_data = train_data.select(range(int(len(train_data) * subsample))) if eval_data is not None: eval_data = eval_data.select(range(int(len(eval_data) * subsample))) def add_length(x): length = len(x['input_ids']) if 'rejected_input_ids' in x: length = max(length, len(x['rejected_input_ids'])) return {'length': length} with zero_first(is_main_process()): train_data = train_data.map(add_length, desc='adding length field', num_proc=NUM_PROC) if eval_data is not None: eval_data = eval_data.map(add_length, desc='adding length field', num_proc=NUM_PROC) if 'prompt_attention_mask' in train_data.column_names: train_data = train_data.remove_columns('prompt_attention_mask') if eval_data is not None: eval_data = eval_data.remove_columns('prompt_attention_mask') if num_repeats: train_data = train_data.repeat(num_repeats) if is_main_process(): print(f'train_data size: {len(train_data)}') if eval_data is not None: print(f'eval_data size: {len(eval_data)}') return train_data, eval_data def combine_datasets(dataset_list, config, sample_weights): sample_weights = torch.tensor(sample_weights, dtype=torch.float32) mode = config.get('dataset_combination_mode', 'concatenate') if mode == 'concatenate': dataset = datasets.concatenate_datasets(dataset_list) elif mode == 'interleave': if 'batch_size_tokens' in config: # batches are formed so they have equal token counts, so interleave datasets based on token counts, not rows avg_lengths = torch.tensor( [dataset['length'].to(torch.float32).mean() for dataset in dataset_list], dtype=torch.float32 ) sample_weights = sample_weights / avg_lengths sample_weights = sample_weights.to( torch.float64 ) # float64 or interleave_datasets complains that probs don't sum to 1 
probs = sample_weights / sample_weights.sum() dataset = datasets.interleave_datasets( dataset_list, probabilities=probs, seed=42, stopping_strategy=config.get('dataset_interleave_stopping_strategy', 'first_exhausted'), ) else: raise ValueError(mode) return dataset # TODO: reduce the extra unneeded left padding caused by accepted and rejected being collated # together. def process_dataset_for_rejected_sampling(dataset): def _rejected_sampling_map_fn(example): input_ids = example['input_ids'] attention_mask = example['attention_mask'] labels = example['labels'] assert input_ids.ndim == 1 assert attention_mask.ndim == 1 assert labels.ndim == 1 # Labels are -100 for masked tokens (the prompt). label_mask = (labels == -100).to(torch.int32) if (label_mask == 0).all(): # Raw text dataset # TODO: allow configuring this completion_start = len(label_mask) // 2 labels[:completion_start] = -100 else: # index of first False completion_start = torch.argmin(label_mask).item() return { 'labels': labels, 'rejected_input_ids': input_ids[:completion_start], 'rejected_attention_mask': attention_mask[:completion_start], 'rejected_labels': labels[:completion_start], } return dataset.map(_rejected_sampling_map_fn, desc='Processing dataset for negative sampling', num_proc=NUM_PROC) def load_datasets(config, tokenizer): if 'datasets' not in config: raise ValueError('Need to specify at least one dataset') train_datasets = [] sample_weights = [] eval_datasets = {} i = 0 for dataset_config in config['datasets']: if 'name' in dataset_config: name = dataset_config['name'] else: name = f'dataset{i}' i += 1 sample_weights.append(dataset_config.get('sample_weight', 1.0)) train, eval = load_single_dataset(dataset_config, tokenizer) train_datasets.append(train) if eval is not None: eval_datasets[name] = eval for dataset_config in config.get('eval_datasets', []): if 'name' in dataset_config: name = dataset_config['name'] else: name = f'dataset{i}' i += 1 eval, _ = load_single_dataset(dataset_config, 
tokenizer) eval_datasets[name] = eval if len(train_datasets) == 1: train_dataset = train_datasets[0] else: with zero_first(is_main_process()): train_dataset = combine_datasets(train_datasets, config, sample_weights=sample_weights) if config.get('rejected_sampling', False): assert 'rl' in config train_dataset = process_dataset_for_rejected_sampling(train_dataset) eval_datasets = {name: process_dataset_for_rejected_sampling(ds) for name, ds in eval_datasets.items()} if 'rl' in config: assert 'rejected_input_ids' in train_dataset.column_names for eval_dataset in eval_datasets.values(): assert 'rejected_input_ids' in eval_dataset.column_names train_dataset.set_format(type='torch') for eval_dataset in eval_datasets.values(): eval_dataset.set_format(type='torch') return train_dataset, eval_datasets # for testing if __name__ == '__main__': import transformers # from datasets import disable_caching # disable_caching() tokenizer = transformers.AutoTokenizer.from_pretrained( sys.argv[1], local_files_only=True, use_fast=False, legacy=True ) tokenizer.pad_token_id = 0 tokenizer.padding_side = 'right' train_data1, eval_data1 = load_raw_dataset('/home/anon/data/test/txt/*.txt', tokenizer, 100, 0.5) train_data2, eval_data2 = load_raw_dataset('/home/anon/data/test/json/*.jsonl', tokenizer, 100, 0.5) print(len(train_data1)) print(len(train_data2)) ================================================ FILE: utils/engine.py ================================================ from collections import deque import time import deepspeed import torch import transformers from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator from deepspeed.runtime import utils as ds_utils from deepspeed.runtime.activation_checkpointing import checkpointing as ds_checkpointing from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.runtime.pipe import p2p, schedule from deepspeed.runtime.pipe.engine import ( BATCH_INPUT_TIMER, PIPE_RECV_GRAD_TIMER, PIPE_RECV_INPUT_TIMER, 
PIPE_SEND_GRAD_TIMER, PIPE_SEND_OUTPUT_TIMER, TRAIN_BATCH_TIMER, PipelineEngine, ) from deepspeed.runtime.pipe.module import LayerSpec, PipelineModule from deepspeed.runtime.pipe.schedule import ( BackwardPass, BufferOpInstruction, ForwardPass, OptimizerStep, PipeInstruction, PipeSchedule, RecvActivation, RecvGrad, ReduceGrads, ReduceTiedGrads, SendActivation, SendGrad, _is_even, _is_odd, ) from deepspeed.runtime.pipe.topology import ProcessTopology from deepspeed.runtime.utils import PartitionedTensor from torch import nn from utils.utils import eta_str, log, is_main_process from utils.dataloader import split_batch, example_to_tuple def initialize( args=None, model=None, config=None, **kwargs ): assert model is not None, 'deepspeed.initialize requires a model' dist_backend = get_accelerator().communication_backend_name() dist.init_distributed(dist_backend=dist_backend) if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None: config = args.deepspeed_config mpu = model.mpu() config_class = DeepSpeedConfig(config, mpu) engine = CustomPipelineEngine( args=args, model=model, mpu=mpu, config=config, config_class=config_class, **kwargs, ) return engine, engine.optimizer def unpack_accepted_rejected(example): batch_size = example['input_ids'].size(0) half = batch_size // 2 for key, tensor in list(example.items()): assert tensor.size(0) == batch_size example[key] = tensor[:half] example['rejected_'+key] = tensor[half:] return example class LoadMicroBatchMultipleBuffers(PipeInstruction): def __init__(self, *buffer_ids, **kwargs): super().__init__(buffer_ids=buffer_ids, **kwargs) class ReferenceLogitsForwardPass(BufferOpInstruction): pass class CustomPipelineEngine(PipelineEngine): def __init__( self, *args, lora_model=None, tokenizer=None, rl_config=None, rejected_sampling=False, rejected_sampling_max_new_tokens=1e9, sampling_temperature=1.0, sampling_min_p=0., sampling_temperature_last=False, **kwargs ): super().__init__(*args, **kwargs) self.total_steps 
= None self.etas = deque() self.rl_config = {} # Assign list to avoid registering the nn.Module self.lora_model = [lora_model] self.tokenizer = tokenizer self.rl_config = rl_config self.rejected_sampling = rejected_sampling self.rejected_sampling_max_new_tokens = rejected_sampling_max_new_tokens eos_token_ids = set() if self.tokenizer is not None and self.tokenizer.eos_token_id is not None: eos_token_ids.add(self.tokenizer.eos_token_id) model_config = self.module.model.config if model_config.eos_token_id: model_eos_token_ids = model_config.eos_token_id if isinstance(model_eos_token_ids, int): model_eos_token_ids = [model_eos_token_ids] eos_token_ids.update(model_eos_token_ids) self.eos_token_ids = eos_token_ids # Sampling configuration. Only supports logits processors that don't use input_ids. self.logits_processor = transformers.LogitsProcessorList() temp = transformers.TemperatureLogitsWarper(float(sampling_temperature)) if sampling_min_p > 0: self.logits_processor.append(transformers.MinPLogitsWarper(float(sampling_min_p))) if sampling_temperature_last: self.logits_processor.append(temp) else: self.logits_processor.insert(0, temp) def set_dataloader(self, loader): self.collate_fn = loader.collate_fn if self.is_first_stage() or self.is_last_stage(): self.training_dataloader = loader self.data_iterator = iter(self.training_dataloader) def train_batch(self): if not torch._C.is_grad_enabled(): raise RuntimeError(f'train_batch() requires gradients enabled. 
Use eval_batch() instead.') self.timers(TRAIN_BATCH_TIMER).start() # sequence length may change between macro batches (but not between gradient accumulation steps) self.reset_activation_shape() train_iterator = self.data_iterator # Negative sampling if self.rejected_sampling: assert self.rl_config self.module.eval() model_inputs = self._sample_from_iterator(train_iterator, self.collate_fn) dist.barrier() if model_inputs is not None: self.set_dataiterator(iter(model_inputs)) self.reset_activation_shape() self.module.train() self._compute_loss = True # Do the work if self.rl_config: method = self.rl_config.get('method', None) if method == 'dpo': sched = DPOTrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id) else: raise NotImplementedError(method) else: sched = schedule.TrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id) self._exec_schedule(sched) agg_losses = self._aggregate_total_losses() # Actual training loss is always the first item. 
self.agg_train_loss = agg_losses[0].mean() self.timers(TRAIN_BATCH_TIMER).stop() if self.global_steps % self.steps_per_print() == 0: if self.global_rank == 0: elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0 iter_time = elapsed / self.steps_per_print() eta = iter_time * (self.total_steps - self.global_steps) self.etas.append(eta) while len(self.etas) > 10: self.etas.popleft() rolling_eta = sum(self.etas) / len(self.etas) tput = self.train_batch_size() / iter_time log( f'step: {self.global_steps:>5} / {self.total_steps:>5} ' f'loss: {self.agg_train_loss:0.4f} ' f'iter time (s): {iter_time:0.3f} ' f'samples/sec: {tput:0.3f} ' f'eta: {eta_str(rolling_eta)} ' ) else: self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) # Monitoring if self.global_rank == 0 and self.monitor.enabled: self.summary_events = [ ('Train/Samples/train_loss', self.agg_train_loss.mean().item(), self.global_samples) ] self.monitor.write_events(self.summary_events) if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0: self.timers.log( [ PIPE_SEND_OUTPUT_TIMER, PIPE_SEND_GRAD_TIMER, PIPE_RECV_INPUT_TIMER, PIPE_RECV_GRAD_TIMER, ] ) # Restore the training iterator self.set_dataiterator(train_iterator) return agg_losses def eval_batch(self, data_iter): # sequence length may change between macro batches (but not between gradient accumulation steps) self.reset_activation_shape() self.module.eval() self._compute_loss = True # Use the provided data iterator train_iterator = self.data_iterator self.set_dataiterator(data_iter) # Negative sampling if self.rejected_sampling: assert self.rl_config dist.barrier() model_inputs = self._sample_from_iterator(data_iter, self.collate_fn) if model_inputs is not None: self.set_dataiterator(iter(model_inputs)) self.reset_activation_shape() # Do the work if self.rl_config: method = self.rl_config.get('method', None) if method == 'dpo': sched = DPOInferenceSchedule(micro_batches=self.micro_batches, stages=self.num_stages, 
stage_id=self.stage_id) else: raise NotImplementedError(method) else: sched = schedule.InferenceSchedule( micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id ) # prevent dead-lock with multiple evals sequence dist.barrier() with torch.no_grad(): self._exec_schedule(sched) # list of losses agg_eval_losses = self._aggregate_total_losses() if self.global_rank == 0 and self.monitor.enabled: self.summary_events = [('Train/Samples/eval_loss', agg_eval_losses[0].mean().item(), self.global_samples)] self.monitor.write_events(self.summary_events) # Restore the training iterator self.set_dataiterator(train_iterator) return agg_eval_losses def sample_batch(self, prompts): assert isinstance(prompts, (list, tuple)) self.reset_activation_shape() self.module.eval() self.module.set_sampling_mode(True) original_micro_batches = self.micro_batches self.micro_batches = len(prompts) dist.barrier() if self.is_first_stage(): # Tokenizer returns dict with 'input_ids', 'attention_mask' keys. # Tensors have batch dimension because we pass list of prompts. examples = [] for prompt in prompts: if not isinstance(prompt, (list, tuple)): prompt = [prompt] examples.append(self.tokenizer(prompt, return_tensors='pt', padding=True)) else: examples = None with torch.no_grad(): examples = self._exec_sampling_schedule(examples) if self.is_first_stage(): text = [self.tokenizer.batch_decode(example['input_ids']) for example in examples] else: text = None self.micro_batches = original_micro_batches self.module.set_sampling_mode(False) return text def _sample_from_iterator(self, data_iter, collate_fn): if data_iter is not None: examples = [unpack_accepted_rejected(next(data_iter)) for _ in range(self.micro_batches)] # TODO: allow configuring this max_total_tokens = max(example['input_ids'].size(1) for example in examples) * 2 else: examples = None # This is okay, max_total_tokens is only checked on the first stage. 
max_total_tokens = None self.module.eval() self.module.set_sampling_mode(True) dist.barrier() with torch.no_grad(): examples = self._exec_sampling_schedule(examples, feature_prefix='rejected_', max_total_tokens=max_total_tokens) if is_main_process(): input_ids = examples[0]['rejected_input_ids'][0] attention_mask = examples[0]['rejected_attention_mask'][0] start = torch.argmax(attention_mask) end = len(attention_mask) - torch.argmax(torch.flip(attention_mask, (0,))) text = self.tokenizer.decode(input_ids[start:end]) print(f'Example of sampled rejected completion:\n{text}') self.module.set_sampling_mode(False) if examples is not None: batch = collate_fn(examples, gradient_accumulation_steps=self.micro_batches) model_inputs = [example_to_tuple(micro_batch) for micro_batch in split_batch(batch, self.micro_batches)] return model_inputs else: return None def _aggregate_total_losses(self): all_agg_outputs = [] # gather each output for all the gradient accumulation steps grouped_outputs = [list(x) for x in zip(*self.fwd_outputs)] # if any are scalar, make them dim 1 so we can concat across DP ranks for outputs in grouped_outputs: for i, output in enumerate(outputs): if output.dim() == 0: outputs[i] = torch.unsqueeze(output, 0) if self.is_last_stage(): agg_sizes = [] # loop to gather all the outputs across DP ranks for outputs in grouped_outputs: # concat all the grad_accum_steps concat_outputs = torch.cat(outputs) if self.is_data_parallel: # might be different sizes across DP ranks, so, gather all the sizes sizes = [None] * self.grid.get_data_parallel_world_size() torch.distributed.all_gather_object( sizes, concat_outputs.size(), group=self.grid.get_data_parallel_group() ) # once we know all the sizes we can gather the results across DP ranks gather_result = [torch.zeros(size).to(self.device) for size in sizes] dist.all_gather(gather_result, concat_outputs, group=self.grid.get_data_parallel_group()) # and finally, concat agg_output = torch.cat(gather_result) else: 
agg_output = concat_outputs agg_sizes.append(agg_output.size()) all_agg_outputs.append(agg_output) # send the sizes, then broadcast to the PP ranks if self.is_pipe_parallel: torch.distributed.broadcast_object_list( [agg_sizes], src=self.global_rank, group=self.grid.get_pipe_parallel_group() ) for agg_output in all_agg_outputs: dist.broadcast(tensor=agg_output, src=self.global_rank, group=self.grid.get_pipe_parallel_group()) else: # get the outputs from the last stage src_rank = self.grid.stage_to_global(self.num_stages - 1) assert src_rank in self.grid.pp_group result = [None] torch.distributed.broadcast_object_list(result, src=src_rank, group=self.grid.get_pipe_parallel_group()) agg_sizes = result[0] for agg_size in agg_sizes: agg_output = torch.zeros(agg_size).to(self.device) dist.broadcast(tensor=agg_output, src=src_rank, group=self.grid.get_pipe_parallel_group()) all_agg_outputs.append(agg_output) return all_agg_outputs # We override this to handle the model returning a list of "losses", but only doing backprop on the first. def _exec_forward_pass(self, buffer_id): self.tput_timer.start() self.mem_status('BEFORE FWD', reset_max=True) if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple): inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id]) else: inputs = self.pipe_buffers['inputs'][buffer_id].clone() # collect the partitioned input from the previous stage if self.is_pipe_partitioned and not self.is_first_stage(): part_input = PartitionedTensor.from_meta( meta=inputs[0], local_part=inputs[1], group=self.grid.get_slice_parallel_group() ) inputs = (part_input.full(), *inputs[2:]) inputs[0].requires_grad = True # skip mask # inputs[1].requires_grad = True part_input = None inputs = inputs[0] if len(inputs) == 1 else inputs self.pipe_buffers['inputs'][buffer_id] = inputs # inputs has no gradient because it is from a cloned tensor outputs = super(PipelineEngine, self).forward(inputs) # Reset activation checkpointing buffers. 
# Need to call this between evaluation iterations if not self.module.training: ds_checkpointing.reset() # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): if isinstance(outputs, tuple): first_output = outputs[0] # TODO: Improve pipe partitioning to pass multiple tensors that require grads assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]) outputs_tail = outputs[1:] elif torch.is_tensor(outputs): first_output = outputs outputs_tail = [] else: raise ValueError('expecting a tensor or a tuple of tensors') part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group()) # Clear the large output data, but save the computation graph first_output.data = torch.zeros(1) self.pipe_buffers['output_tensors'][buffer_id] = first_output # Inject the partitioned tensor into the output before sending outputs = (part.to_meta(), part.data(), *outputs_tail) part = None self.pipe_buffers['outputs'][buffer_id] = outputs # Optionally compute loss on the last device if self.is_last_stage(): if self._compute_loss and self.module.loss_fn is not None: labels = self.pipe_buffers['labels'][buffer_id] losses = self.module.loss_fn(outputs, labels) else: # Some models just return loss from forward() losses = outputs if self.eval_return_logits: self.outputs = outputs if isinstance(losses, torch.Tensor): self.loss = losses self.fwd_outputs.append([self.loss.detach()]) else: self.loss = losses[0] self.fwd_outputs.append([l.detach() for l in losses]) def _exec_load_micro_batch_multiple_buffers(self, buffer_ids): if self.wall_clock_breakdown(): self.timers(BATCH_INPUT_TIMER).start() batch = self._next_batch() if self.is_first_stage(): loaded = None if torch.is_tensor(batch[0]): loaded = batch[0].clone().to(self.device).detach() if ( self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline['use_reentrant'] ): loaded.requires_grad = 
loaded.is_floating_point() else: assert isinstance(batch[0], (tuple, list)) # Assume list or tuple loaded = [] for x in batch[0]: assert torch.is_tensor(x) mine = x.clone().detach().to(self.device) if ( self._config.pipeline['activation_checkpoint_interval'] > 0 and self._config.pipeline['use_reentrant'] ): mine.requires_grad = mine.is_floating_point() loaded.append(mine) loaded = tuple(loaded) for buffer_id in buffer_ids: self.pipe_buffers['inputs'][buffer_id] = loaded if self.is_last_stage(): loaded = batch[1] if torch.is_tensor(batch[1]): loaded = batch[1].to(self.device) # XXX: torch 1.6.0 DataLoader will auto convert tuple to list elif isinstance(batch[1], (tuple, list)): loaded = [] for x in batch[1]: assert torch.is_tensor(x) x = x.to(self.device).detach() loaded.append(x) loaded = tuple(loaded) for buffer_id in buffer_ids: self.pipe_buffers['labels'][buffer_id] = loaded if self.wall_clock_breakdown(): self.timers(BATCH_INPUT_TIMER).stop() @torch.no_grad() def _exec_reference_logits_forward_pass(self, buffer_id): self.lora_model[0].disable_adapter_layers() self.module.set_dpo_reference_mode(True) if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple): inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id]) else: inputs = self.pipe_buffers['inputs'][buffer_id].clone() # collect the partitioned input from the previous stage if self.is_pipe_partitioned and not self.is_first_stage(): if self.pipe_partition_input_meta_cache is None: self.pipe_partition_input_meta_cache = inputs[0].to('cpu') part_input = PartitionedTensor.from_meta( meta=self.pipe_partition_input_meta_cache, local_part=inputs[1], group=self.grid.get_slice_parallel_group(), ) inputs = (part_input.full(), *inputs[2:]) inputs[0].requires_grad = True # skip mask # inputs[1].requires_grad = True part_input = None inputs = inputs[0] if len(inputs) == 1 else inputs self.pipe_buffers['inputs'][buffer_id] = inputs # inputs has no gradient because it is from a cloned tensor outputs = 
super(PipelineEngine, self).forward(inputs) # Reset activation checkpointing buffers. # Need to call this between evaluation iterations if not self.module.training: ds_checkpointing.reset() # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): if isinstance(outputs, tuple): first_output = outputs[0] # TODO: Improve pipe partitioning to pass multiple tensors that require grads assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]) outputs_tail = outputs[1:] elif torch.is_tensor(outputs): first_output = outputs outputs_tail = [] else: raise ValueError('expecting a tensor or a tuple of tensors') part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group()) # Clear the large output data, but save the computation graph first_output.data = torch.zeros(1, device=first_output.data.device) self.pipe_buffers['output_tensors'][buffer_id] = first_output # Inject the partitioned tensor into the output before sending outputs = (part.to_meta(), part.data(), *outputs_tail) part = None self.pipe_buffers['outputs'][buffer_id] = outputs self.lora_model[0].enable_adapter_layers() self.module.set_dpo_reference_mode(False) def _exec_send_micro_batch_id(self, send_micro_batch_id): assert isinstance(send_micro_batch_id, int) if self.num_stages == 1: return send_micro_batch_id send_micro_batch_id = torch.tensor(send_micro_batch_id, device=self.device) recv_micro_batch_id = torch.tensor(-1, device=self.device) if _is_even(self.stage_id): if not self.is_last_stage(): p2p.send(send_micro_batch_id, self.next_stage) if not self.is_first_stage(): p2p.recv(recv_micro_batch_id, self.prev_stage) else: if not self.is_first_stage(): p2p.recv(recv_micro_batch_id, self.prev_stage) if not self.is_last_stage(): p2p.send(send_micro_batch_id, self.next_stage) # last stage sends to first stage if self.is_first_stage(): p2p.recv(recv_micro_batch_id, self.num_stages - 1) if 
self.is_last_stage(): p2p.send(send_micro_batch_id, 0) return recv_micro_batch_id.item() def _exec_load_micro_batch_for_sampling(self, buffer_id, inputs): loaded = ( inputs['input_ids'], inputs['attention_mask'], torch.tensor([0]), # labels must be provided, so use a dummy ) loaded = tuple(x.clone().detach().to(self.device) for x in loaded) self.pipe_buffers['inputs'][buffer_id] = loaded @torch.no_grad() def _exec_sampling_forward_pass(self, buffer_id): if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple): inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id]) else: inputs = self.pipe_buffers['inputs'][buffer_id].clone() # collect the partitioned input from the previous stage if self.is_pipe_partitioned and not self.is_first_stage(): if self.pipe_partition_input_meta_cache is None: self.pipe_partition_input_meta_cache = inputs[0].to('cpu') part_input = PartitionedTensor.from_meta( meta=self.pipe_partition_input_meta_cache, local_part=inputs[1], group=self.grid.get_slice_parallel_group(), ) inputs = (part_input.full(), *inputs[2:]) inputs[0].requires_grad = True # skip mask # inputs[1].requires_grad = True part_input = None inputs = inputs[0] if len(inputs) == 1 else inputs self.pipe_buffers['inputs'][buffer_id] = inputs # inputs has no gradient because it is from a cloned tensor outputs = super(PipelineEngine, self).forward(inputs) # Reset activation checkpointing buffers. 
# Need to call this between evaluation iterations if not self.module.training: ds_checkpointing.reset() # Partition the outputs if we are not the last stage if self.is_pipe_partitioned and not self.is_last_stage(): if isinstance(outputs, tuple): first_output = outputs[0] # TODO: Improve pipe partitioning to pass multiple tensors that require grads assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]) outputs_tail = outputs[1:] elif torch.is_tensor(outputs): first_output = outputs outputs_tail = [] else: raise ValueError('expecting a tensor or a tuple of tensors') part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group()) # Clear the large output data, but save the computation graph first_output.data = torch.zeros(1, device=first_output.data.device) self.pipe_buffers['output_tensors'][buffer_id] = first_output # Inject the partitioned tensor into the output before sending outputs = (part.to_meta(), part.data(), *outputs_tail) part = None self.pipe_buffers['outputs'][buffer_id] = outputs def _sample_from_logits(self, buffer_id): logits = self.pipe_buffers['outputs'][buffer_id].squeeze(1) logits = self.logits_processor(None, logits) probs = torch.nn.functional.softmax(logits, dim=-1) input_ids = torch.multinomial(probs, num_samples=1) # Logically you would squeeze(1) to remove the multinomial num_samples dimension, then # unsqueeze(1) to add back the sequence_length dimension. But those just cancel out. return input_ids def _valid_stage(self, stage_id): return 0 <= stage_id < self.num_stages def _valid_micro_batch(self, micro_batch_id): return 0 <= micro_batch_id < self.micro_batches def _exec_sampling_schedule(self, examples, feature_prefix='', max_total_tokens=1e9): start = time.time() input_ids_key = f'{feature_prefix}input_ids' attention_mask_key = f'{feature_prefix}attention_mask' labels_key = f'{feature_prefix}labels' # Reserve and reset buffers. 
self._reserve_pipe_buffers(2) self.fwd_outputs = [] eos_token_ids = torch.tensor(list(self.eos_token_ids)) finished = False if self.is_first_stage(): num_batches_done = 0 num_batches = len(examples) queue = deque() for i, example in enumerate(examples): example['done'] = torch.tensor([False]*example[input_ids_key].size(0)) example['num_new_tokens'] = 0 queue.append((i, { 'input_ids': example[input_ids_key], 'attention_mask': example[attention_mask_key], })) step_id = 0 micro_batch_id = -1 prev_micro_batch_id = -1 while not finished: # Alternate send/recv buffers if _is_even(self.stage_id): recv_buf = step_id % 2 send_buf = (step_id + 1) % 2 else: recv_buf = (step_id + 1) % 2 send_buf = step_id % 2 # Load from the queue on the first stage. if self.is_first_stage(): if len(queue) > 0: micro_batch_id, inputs = queue.popleft() self._exec_load_micro_batch_for_sampling(recv_buf, inputs) else: micro_batch_id = -1 # Send / receive activations if needed. if _is_even(self.stage_id): if self._valid_stage(self.next_stage): if self._valid_micro_batch(prev_micro_batch_id): self._exec_send_activations(send_buf) if self._valid_stage(self.prev_stage): if self._valid_micro_batch(micro_batch_id): self._exec_recv_activations(recv_buf) else: if self._valid_stage(self.prev_stage): if self._valid_micro_batch(micro_batch_id): self._exec_recv_activations(recv_buf) if self._valid_stage(self.next_stage): if self._valid_micro_batch(prev_micro_batch_id): self._exec_send_activations(send_buf) # Send micro_batch_id to next stage. Last stage wraps around and sends to first stage. prev_micro_batch_id = micro_batch_id micro_batch_id = self._exec_send_micro_batch_id(micro_batch_id) # Run forward(). # Note that prev_micro_batch_id is actually the micro_batch_id of the current step. 
if self._valid_micro_batch(prev_micro_batch_id): self.model.set_cache(prev_micro_batch_id) self._exec_sampling_forward_pass(recv_buf) if self.is_last_stage(): input_ids = self._sample_from_logits(recv_buf) if self.num_stages > 1: p2p.send(input_ids, 0) # First stage got a valid micro_batch_id from the last stage. Receive the input_ids and process them. if self.is_first_stage() and self._valid_micro_batch(micro_batch_id): example = examples[micro_batch_id] batch_size = example[input_ids_key].size(0) if self.num_stages > 1: input_ids = torch.full((batch_size, 1), -1, device=self.device) p2p.recv(input_ids, self.num_stages - 1) assert input_ids.size(-1) == 1, input_ids.shape if example['num_new_tokens'] >= self.rejected_sampling_max_new_tokens: finished = True if example[input_ids_key].size(1) >= max_total_tokens: finished = True if not finished: input_ids = input_ids.to('cpu') prev_done = example['done'] # Determine which items in the batch are done generating. done = prev_done | (input_ids == eos_token_ids).any(-1) example['done'] = done batch_done = done.all().item() # Output pad token and 0 attention mask for items in the batch that are already done. prev_done_reshaped = prev_done.unsqueeze(-1) input_ids = torch.where(prev_done_reshaped, self.tokenizer.pad_token_id, input_ids) attention_mask_extension = torch.where(prev_done_reshaped, 0, 1) labels_extention = torch.where(prev_done_reshaped, -100, input_ids) input_ids = torch.cat([example[input_ids_key], input_ids], dim=-1) example[input_ids_key] = input_ids if labels_key in example: example[labels_key] = torch.cat([example[labels_key], labels_extention], dim=-1) attention_mask = torch.cat([example[attention_mask_key], attention_mask_extension], dim=-1) example[attention_mask_key] = attention_mask example['num_new_tokens'] += 1 if batch_done: num_batches_done += 1 finished = (num_batches_done == num_batches) else: # Model needs full attention mask, but only most recent sampled input_id. 
queue.append( ( micro_batch_id, { 'input_ids': input_ids[..., -1:], 'attention_mask': attention_mask, }, ) ) # Broadcast finished from first stage to all other stages so they can exit the loop. src_rank = self.grid.stage_to_global(0) finished = [finished] if self.is_first_stage() else [None] torch.distributed.broadcast_object_list( finished, src=src_rank, group=self.grid.get_pipe_parallel_group() ) finished = finished[0] step_id += 1 # end while loop if self.is_first_stage(): total_new_tokens = 0 for example in examples: total_new_tokens += example['num_new_tokens'] * batch_size del example['done'] del example['num_new_tokens'] if is_main_process(): duration = time.time() - start tps = total_new_tokens / duration print(f'Total sampling time: {duration:.1f}, average tok/s: {tps:.1f}') dist.barrier() return examples # make our forward pass method apply PipelineEngine._INSTRUCTION_MAP[schedule.ForwardPass] = _exec_forward_pass PipelineEngine._INSTRUCTION_MAP[LoadMicroBatchMultipleBuffers] = _exec_load_micro_batch_multiple_buffers PipelineEngine._INSTRUCTION_MAP[ReferenceLogitsForwardPass] = _exec_reference_logits_forward_pass class ColumnMajorParallelTopology(ProcessTopology): """ A topology specialisation for hybrid data+pipeline parallelism optimized for LoRA training: - Sends high-volume "per token" hidden states over PCIe/NVLink. - Sends lower-volume "per step" LoRA gradient reductions over Ethernet/InfiniBand. 
""" def __init__(self, num_pp, num_dp): # Swap the axes and dims to change the rank mapping super().__init__(axes=['data', 'pipe'], dims=[num_dp, num_pp]) class CustomPipelineModule(PipelineModule): def __init__(self, layers, use_column_major_topology, model=None, **kwargs): # Assign to list to avoid registering the nn.Module self._model = [model] # Hybrid LoRA data+pipeline parallelism may want to use "column-major" layout if use_column_major_topology: world_size = dist.get_world_size() num_stages = kwargs.get('num_stages') if num_stages > 1 and world_size > 1: assert world_size % num_stages == 0, ( f'world_size ({world_size}) must be divisible by num_stages ({num_stages})' ) num_dp = world_size // num_stages kwargs['topology'] = ColumnMajorParallelTopology(num_pp=num_stages, num_dp=num_dp) super().__init__(layers, **kwargs) @property def model(self): return self._model[0] def set_dpo_reference_mode(self, dpo_reference_mode): self.model.set_dpo_reference_mode(dpo_reference_mode) def set_sampling_mode(self, sampling_mode): self.model.set_sampling_mode(sampling_mode) def _partition_layers(self, method='uniform'): num_stages = self._topo.get_dim('pipe') stage_id = self._topo.get_coord(self.global_rank).pipe if self.global_rank == 0: print(f'Partitioning pipeline stages with method {method}') method = method.lower() estimated_sizes = None # Each stage gets a simple uniform number of layers. 
if method == 'uniform': num_layers = len(self._layer_specs) self.parts = ds_utils.partition_uniform(num_items=num_layers, num_parts=num_stages) elif method == 'parameters': param_counts = self._count_layer_params() self.parts = ds_utils.partition_balanced(weights=param_counts, num_parts=num_stages) elif method.startswith('type:'): layertype = method.split(':')[1] binary_weights = [0] * len(self._layer_specs) for idx in self._find_layer_type(layertype): binary_weights[idx] = 1 self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages) elif method == 'profile': raise NotImplementedError(f'Partitioning method {method} not implemented.') elif method == 'estimated_size': estimated_sizes = [getattr(l, 'estimated_size', 0) for l in self._layer_specs] self.parts = ds_utils.partition_balanced(weights=estimated_sizes, num_parts=num_stages) else: raise NotImplementedError(f'Partitioning method {method} not implemented.') # Print some information on the partitioning. if self.global_rank == 0: for stage in range(num_stages): start = self.parts[stage] stop = self.parts[stage + 1] print(f'stage={stage} layers={stop - start}') for idx, layer in enumerate(self._layer_specs[start:stop]): name = str(layer) if isinstance(layer, LayerSpec): name = layer.typename.__name__ if isinstance(layer, nn.Module): name = layer.__class__.__name__ else: try: name = layer.__name__ except AttributeError: pass logstr = f' {idx + start:2d}: {name}' if estimated_sizes: es = estimated_sizes[idx + start] logstr += f', estimated size: {es}' print(logstr) if self.loss_fn: try: print(f' loss: {self.loss_fn.__name__}') except AttributeError: print(f' loss: {self.loss_fn.__class__.__name__}') deepspeed.comm.barrier() self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1]) class DPOTrainSchedule(PipeSchedule): """Train schedule for DPO. 
Does an extra forward pass for the reference logits.""" def steps(self): prev_micro_batch_id = -1 total_steps = 2 * (self.micro_batches + self.stages - 1) forward_step_id = 0 ref_logits_buf = self.num_pipe_buffers() - 1 for step_id in range(total_steps): # Map the step of the pipeline to the micro-batch id and also whether it is a # forward or backward pass step. micro_batch_id, is_forward = self._step_to_micro_batch(step_id) if self._valid_micro_batch(prev_micro_batch_id): prev_buffer = self._buffer_idx(prev_micro_batch_id) if self._valid_micro_batch(micro_batch_id): curr_buffer = self._buffer_idx(micro_batch_id) cmds = [] # Exchange activations if is_forward: if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.prev_stage): cmds.append(SendGrad(prev_buffer)) if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.prev_stage): cmds.append(RecvActivation(ref_logits_buf)) cmds.append(RecvActivation(curr_buffer)) else: if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.next_stage): cmds.append(RecvGrad(curr_buffer)) if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.next_stage): cmds.append(SendActivation(ref_logits_buf)) cmds.append(SendActivation(prev_buffer)) # First/last stage loads if self.stage_id == 0 or self.stage_id == self.stages - 1: if is_forward and self._valid_micro_batch(micro_batch_id): # Load for normal forward and reference logits forward. cmds.append(LoadMicroBatchMultipleBuffers(curr_buffer, ref_logits_buf)) # Computation if self._valid_micro_batch(micro_batch_id): if is_forward: # Reference logits forward. 
cmds.append(ReferenceLogitsForwardPass(ref_logits_buf)) cmds.append(ForwardPass(curr_buffer)) forward_step_id += 1 else: cmds.append(BackwardPass(curr_buffer)) # Model step at the end of the batch if step_id == total_steps - 1: cmds.append(ReduceTiedGrads()) cmds.append(ReduceGrads()) cmds.append(OptimizerStep()) # Prepare state for next time prev_micro_batch_id = micro_batch_id yield cmds def num_pipe_buffers(self): buffers = min(self.stages - self.stage_id, self.micro_batches) # +1 buffer for reference logits forward pass. # Unlike inference, we only need 1 buffer, since alternating forward/backward passes means a stage # is never sending and receiving activations on the same step. return max(2, buffers) + 1 def _step_to_micro_batch(self, step_id): if _is_even(step_id) and _is_even(self.stage_id): micro_batch_id = self._even_step_forward_id(step_id) is_forward = True elif _is_odd(step_id) and _is_odd(self.stage_id): micro_batch_id = self._odd_step_forward_id(step_id) is_forward = True elif _is_even(step_id) and _is_odd(self.stage_id): micro_batch_id = self._even_step_backward_id(step_id) is_forward = False elif _is_odd(step_id) and _is_even(self.stage_id): micro_batch_id = self._odd_step_backward_id(step_id) is_forward = False else: assert False return micro_batch_id, is_forward def _even_step_forward_id(self, step_id): base = step_id // 2 micro_batch_id = int(base - self.stage_id // 2) return micro_batch_id def _odd_step_forward_id(self, step_id): base = (step_id - 1) // 2 micro_batch_id = int(base - self.stage_id // 2) return micro_batch_id def _even_step_backward_id(self, step_id): base = step_id // 2 micro_batch_id = int(base - self.stages + (self.stage_id + 1) // 2) return micro_batch_id def _odd_step_backward_id(self, step_id): base = ((step_id - 1) // 2) - self.stages + 1 micro_batch_id = int(base + self.stage_id // 2) return micro_batch_id # Override to account for the extra buffer used for reference logit forward pass. 
def _buffer_idx(self, micro_batch_id): assert self._valid_micro_batch(micro_batch_id) return micro_batch_id % (self.num_pipe_buffers() - 1) class DPOInferenceSchedule(PipeSchedule): def steps(self): total_steps = self.micro_batches + self.stages - 1 for step_id in range(total_steps): cmds = [] micro_batch_id = step_id - self.stage_id # Alternate send/recv buffers if _is_even(self.stage_id): recv_buf = step_id % 2 send_buf = (step_id + 1) % 2 else: recv_buf = (step_id + 1) % 2 send_buf = step_id % 2 ref_recv_buf = recv_buf + 2 ref_send_buf = send_buf + 2 if self.is_first_stage or self.is_last_stage: if self._valid_micro_batch(micro_batch_id): cmds.append(LoadMicroBatchMultipleBuffers(recv_buf, ref_recv_buf)) if _is_even(self.stage_id): if self._valid_stage(self.next_stage): if self._valid_micro_batch(micro_batch_id - 1): cmds.append(SendActivation(ref_send_buf)) cmds.append(SendActivation(send_buf)) if self._valid_stage(self.prev_stage): if self._valid_micro_batch(micro_batch_id): cmds.append(RecvActivation(ref_recv_buf)) cmds.append(RecvActivation(recv_buf)) else: if self._valid_stage(self.prev_stage): if self._valid_micro_batch(micro_batch_id): cmds.append(RecvActivation(ref_recv_buf)) cmds.append(RecvActivation(recv_buf)) if self._valid_stage(self.next_stage): if self._valid_micro_batch(micro_batch_id - 1): cmds.append(SendActivation(ref_send_buf)) cmds.append(SendActivation(send_buf)) if self._valid_micro_batch(micro_batch_id): cmds.append(ReferenceLogitsForwardPass(ref_recv_buf)) cmds.append(ForwardPass(recv_buf)) yield cmds def num_pipe_buffers(self): return 4 ================================================ FILE: utils/hqq_utils.py ================================================ from dataclasses import asdict, dataclass, field from typing import Any import torch import transformers from hqq.core import quantize as hqq_quantize from torch import nn import peft from utils.utils import DTYPE_MAP # Monkeypatch PEFT so that target_modules='all-linear' targets the 
HQQLinear layers, which are not # subclasses of nn.Linear, unlike BNB. def _maybe_include_all_linear_layers(peft_config: peft.PeftConfig, model: nn.Module) -> peft.PeftConfig: """ Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py """ # if `target_modules` is a string, convert to lower case and check if it matches "all-linear" if not ( isinstance(peft_config.target_modules, str) and peft_config.target_modules.lower() == peft.tuners.tuners_utils.INCLUDE_LINEAR_LAYERS_SHORTHAND ): return peft_config if not isinstance(model, transformers.PreTrainedModel): raise ValueError( f'Only instances of PreTrainedModel support `target_modules={peft.tuners.tuners_utils.INCLUDE_LINEAR_LAYERS_SHORTHAND!r}`' ) # add HQQLinear linear_classes = (torch.nn.Linear, transformers.pytorch_utils.Conv1D, hqq_quantize.HQQLinear) linear_module_names = set() for name, module in model.named_modules(): # match with all linear classes. 
if isinstance(module, linear_classes): names = name.rsplit('.', 1)[-1] # get the base name linear_module_names.add(names) # ignore the last classification head for text generation models output_emb = model.get_output_embeddings() if output_emb is not None: last_module_name = [name for name, module in model.named_modules() if module is output_emb][0] linear_module_names -= {last_module_name} peft_config.target_modules = linear_module_names return peft_config peft.tuners.tuners_utils._maybe_include_all_linear_layers = _maybe_include_all_linear_layers @dataclass class CustomHQQConfig: nbits: int = 4 group_size: int = 64 view_as_float: bool = False axis: int = 0 dynamic_config: dict[str, Any] = field(default_factory=dict) skip_modules: list[str] = field(default_factory=lambda: ['lm_head']) compute_dtype: str = 'float32' def __post_init__(self): self.compute_dtype = DTYPE_MAP[self.compute_dtype] def use_aten(self): return self.axis == 0 and all(d.get('axis', self.axis) == 0 for d in self.dynamic_config.values()) def get_dict(self, full_name): """Get final config dict to use for quantization, for module with full_name.""" kwargs = asdict(self) kwargs.pop('compute_dtype') kwargs.pop('skip_modules') dynamic_config = kwargs.pop('dynamic_config') for key, value in dynamic_config.items(): if key in full_name: kwargs.update(value) break return hqq_quantize.BaseQuantizeConfig(**kwargs) ================================================ FILE: utils/saver.py ================================================ import glob import json import os import shutil import sys import time from pathlib import Path import deepspeed import torch import transformers from safetensors.torch import save_file from utils.utils import DTYPE_MAP, is_main_process last_checkpoint_time = None def need_to_checkpoint(config): if 'checkpoint_every_n_minutes' not in config: return False global last_checkpoint_time checkpoint = False # rank 0 tracks if we need to checkpoint, broadcasts to everyone else if 
is_main_process(): current_time = time.time() if last_checkpoint_time is None: last_checkpoint_time = current_time elif (current_time - last_checkpoint_time) / 60 > config['checkpoint_every_n_minutes']: checkpoint = True last_checkpoint_time = current_time result = [checkpoint] torch.distributed.broadcast_object_list(result, src=0) return result[0] def convert_state_dict_dtype(state_dict, dtype): for key, v in state_dict.items(): state_dict[key] = v.to(device='cpu', dtype=DTYPE_MAP[dtype]) class Saver: def __init__(self, model_engine, pipeline_model, train_dataloader, lora_config, save_root, args, config): self.model_engine = model_engine self.pipeline_model = pipeline_model self.train_dataloader = train_dataloader self.lora_config = lora_config self.save_root = save_root + '/' if save_root[-1] != '/' else save_root self.args = args self.config = config self.keep_states = config.get('keep_states', -1) self.checkpoint_on_save = config.get('checkpoint_on_save', False) self.chrono_states = { 'step': [], 'global_step': [], } # Load best loss from disk, if found, and if a best_loss model dir exists self.best_loss = None best_loss_path = os.path.join(self.save_root, 'best_loss.txt') if os.path.exists(best_loss_path) and os.path.isdir(os.path.join(self.save_root, 'best_loss')): with open(best_loss_path) as f: self.best_loss = float(f.read()) print(f'Loaded best loss from disk: {self.best_loss}') # TODO: this is pretty hacky. Is there a way to get the state_dict from the lora model directly, # but still know which layers the given pipeline parallel stage actually trained? 
def save_lora(self, name): dp_id = self.model_engine.grid.get_data_parallel_rank() stage_id = self.model_engine.grid.get_pipe_parallel_rank() save_dir = self.save_root + name tmp_dir = os.path.join(save_dir, 'tmp') if dp_id == 0 and stage_id == 0: os.makedirs(tmp_dir, exist_ok=False) deepspeed.comm.barrier() if dp_id == 0: partial_state_dict = {} for name, p in self.pipeline_model.named_parameters(): if p.requires_grad: if not hasattr(p, 'original_name'): print( f'WARNING: parameter {name} requires_grad but does not have original_name. Not saving it.' ) continue partial_state_dict[p.original_name.replace('.default', '').replace('.modules_to_save', '')] = ( p.detach() ) if 'save_dtype' in self.config: convert_state_dict_dtype(partial_state_dict, self.config['save_dtype']) torch.save(partial_state_dict, os.path.join(tmp_dir, f'state_dict_{stage_id}.bin')) deepspeed.comm.barrier() if dp_id == 0 and stage_id == 0: state_dict = {} for path in glob.glob(os.path.join(tmp_dir, '*.bin')): state_dict.update(torch.load(path, map_location='cpu')) torch.save(state_dict, os.path.join(save_dir, 'adapter_model.bin')) self.lora_config.save_pretrained(save_dir) shutil.copy(self.args.config, save_dir) if hasattr(self.args, 'deepspeed_config') and self.args.deepspeed_config is not None: shutil.copy(self.args.deepspeed_config, save_dir) self.safe_rmtree(tmp_dir) def save_full_model(self, name, max_shard_size='5GB'): dp_id = self.model_engine.grid.get_data_parallel_rank() stage_id = self.model_engine.grid.get_pipe_parallel_rank() save_dir = self.save_root + name tmp_dir = os.path.join(save_dir, 'tmp') if dp_id == 0 and stage_id == 0: os.makedirs(tmp_dir, exist_ok=False) deepspeed.comm.barrier() if dp_id == 0: # With BF16_Optimizer, we get pickle errors unless we do p.detach(). I have no idea why. 
partial_state_dict = {p.original_name: p.detach() for p in self.pipeline_model.parameters()} if 'save_dtype' in self.config: convert_state_dict_dtype(partial_state_dict, self.config['save_dtype']) torch.save(partial_state_dict, os.path.join(tmp_dir, f'state_dict_{stage_id}.bin')) deepspeed.comm.barrier() if dp_id == 0 and stage_id == 0: state_dict = {} for path in glob.glob(os.path.join(tmp_dir, '*.bin')): state_dict.update(torch.load(path, map_location='cpu')) shards, index = transformers.modeling_utils.shard_checkpoint( state_dict, max_shard_size=max_shard_size, weights_name='model.safetensors' ) for shard_file, shard in shards.items(): save_file(shard, os.path.join(save_dir, shard_file), metadata={'format': 'pt'}) if index is not None: save_index_file = 'model.safetensors.index.json' save_index_file = os.path.join(save_dir, save_index_file) # Save the index as well with open(save_index_file, 'w', encoding='utf-8') as f: content = json.dumps(index, indent=2, sort_keys=True) + '\n' f.write(content) shutil.copy(self.args.config, save_dir) if hasattr(self.args, 'deepspeed_config') and self.args.deepspeed_config is not None: shutil.copy(self.args.deepspeed_config, save_dir) additional_files_to_copy = [ 'added_tokens.json', 'config.json', 'generation_config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'tokenizer.model', ] for path in glob.glob(os.path.join(self.config['model'], '*')): if os.path.basename(path) in additional_files_to_copy: shutil.copy(path, save_dir) self.safe_rmtree(tmp_dir) def will_save(self, type, name): if self.keep_states <= 0 or not is_main_process(): return if type == 'step': self.chrono_states['step'].append(name) if len(self.chrono_states['step']) > self.keep_states: print(f'Deleting {self.chrono_states["step"][0]}') self.safe_rmtree(os.path.join(self.save_root, self.chrono_states['step'].pop(0))) elif type == 'global_step': self.chrono_states['global_step'].append(name) if 
len(self.chrono_states['global_step']) > self.keep_states: print(f'Deleting {self.chrono_states["global_step"][0]}') self.safe_rmtree(os.path.join(self.save_root, self.chrono_states['global_step'].pop(0))) else: raise ValueError(f'Unknown save type: {type}') def save_model(self, name): # ignore epoch saves for chrono_states if name.startswith('step'): self.will_save('step', name) self.save_full_model(name) if self.lora_config is None else self.save_lora(name) def save_checkpoint(self, step): self.will_save('global_step', f'global_step{step}') self.model_engine.save_checkpoint( self.save_root, client_state={ 'step': step, 'custom_loader': self.train_dataloader.state_dict(), }, save_latest=True, exclude_frozen_parameters=True, ) def process_epoch(self, epoch, step): save_every_n_epochs = self.config.get('save_every_n_epochs', 1) save_checkpoint_on_epoch_end = self.config.get('save_checkpoint_on_epoch_end', True) if self.train_dataloader.epoch != epoch: if save_checkpoint_on_epoch_end: self.save_checkpoint(step) if epoch % save_every_n_epochs == 0: self.save_model(f'epoch{epoch}') epoch = self.train_dataloader.epoch if epoch > self.config['epochs']: return None if is_main_process(): print(f'Started new epoch: {epoch}') return epoch def process_step(self, step): # Look at some simple "signal files" the user can write to save and optionally quit manually should_manually_save = False should_manually_quit = False save_signal_file = Path(self.save_root) / 'save' save_quit_signal_file = Path(self.save_root) / 'save_quit' if save_signal_file.exists() and save_signal_file.is_file(): should_manually_save = True deepspeed.comm.barrier() if is_main_process(): os.remove(save_signal_file) elif save_quit_signal_file.exists() and save_quit_signal_file.is_file(): should_manually_save = True should_manually_quit = True deepspeed.comm.barrier() if is_main_process(): os.remove(save_quit_signal_file) if ('save_steps' in self.config and step % self.config['save_steps'] == 0) or 
should_manually_save: self.save_model(f'step{step}') if self.checkpoint_on_save and not should_manually_save: self.save_checkpoint(step) pending_save_best_loss = os.path.exists(os.path.join(self.save_root, '.pending_save_best_loss')) if pending_save_best_loss: self.save_model('best_loss') if is_main_process(): if self.old_best is not None: print( f'New best evaluation loss: {self.best_loss:.4f} from {self.old_best:.4f} (Δ{self.old_best - self.best_loss:.5f} [{100 * (1 - self.best_loss / self.old_best):.2f}%])' ) else: print(f'New best evaluation loss: {self.best_loss:.4f}') os.replace( os.path.join(self.save_root, '.pending_save_best_loss'), os.path.join(self.save_root, 'best_loss.txt'), ) if (not self.checkpoint_on_save and need_to_checkpoint(self.config)) or should_manually_save: self.save_checkpoint(step) if should_manually_quit: print('Manually quitting') sys.exit() def append_eval_results(self, loss, save_best=True): if loss is not None: if self.best_loss is None: print(f'Evaluation loss: {loss:.4f}') elif loss >= self.best_loss: print( f'Evaluation loss: {loss:.4f} (best: {self.best_loss:.4f}, Δ: {self.best_loss - loss:.5f} [{100 * (1 - loss / self.best_loss):.2f}%])' ) if self.best_loss is None or loss < self.best_loss: self.old_best = self.best_loss self.best_loss = loss if save_best: with open(os.path.join(self.save_root, '.pending_save_best_loss'), 'w') as f: f.write(str(self.best_loss)) deepspeed.comm.barrier() # Attempt to remove a directory tree using exponential backoff for retries (default max wait = 31s) def safe_rmtree(self, dir_path, max_retries=5, initial_wait_seconds=1): for attempt in range(max_retries + 1): try: shutil.rmtree(dir_path) return except OSError as e: if attempt == max_retries: raise e time.sleep(initial_wait_seconds * 2**attempt) ================================================ FILE: utils/unsloth_utils.py ================================================ # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel 
Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # I (tdrussell) made a few modifications. import torch from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): """ Code licensed under LGPL Saves VRAM by smartly offloading to RAM. Tiny hit to performance, since we mask the movement via non blocking calls. 
""" @staticmethod @torch.amp.custom_fwd(device_type='cuda') def forward(ctx, forward_function, hidden_states, *args): saved_hidden_states = hidden_states.to('cpu', non_blocking=True) with torch.no_grad(): output = forward_function(hidden_states, *args) ctx.save_for_backward(saved_hidden_states) ctx.forward_function = forward_function ctx.args = args return output pass @staticmethod @torch.amp.custom_bwd(device_type='cuda') def backward(ctx, *grads): (hidden_states,) = ctx.saved_tensors hidden_states = hidden_states.to('cuda', non_blocking=True).detach() hidden_states.requires_grad_(True) args = detach_variable(ctx.args) inputs = (hidden_states,) + args with torch.enable_grad(): outputs = ctx.forward_function(*inputs) output_tensors = [] grad_tensors = [] for out, grad in zip(outputs, grads): if out.requires_grad: output_tensors.append(out) grad_tensors.append(grad) torch.autograd.backward(output_tensors, grad_tensors) return (None,) + tuple(input.grad for input in inputs) pass pass @torch._disable_dynamo def unsloth_checkpoint(function, *args): return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) ================================================ FILE: utils/utils.py ================================================ import os.path import sys from datetime import datetime import torch sys.path.insert(0, os.path.abspath('axolotl/src')) from axolotl.utils.distributed import is_main_process, zero_first # type: ignore # noqa DTYPE_MAP = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16} # Simplified logger-like printer. def log(msg): print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}] [INFO] [qlora-pipe] {msg}') def eta_str(eta): eta = int(eta) if eta > 3600: return f'{eta // 3600}h{(eta % 3600) // 60}m' return f'{eta // 60}m{eta % 60}s' if eta > 60 else f'{eta}s'