Repository: tdrussell/qlora-pipe
Branch: main
Commit: 90b8e7d62034
Files: 30
Total size: 232.7 KB
Directory structure:
gitextract_z9_xsnea/
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── examples/
│ ├── alpaca.yml
│ ├── capybara.yml
│ ├── config.toml
│ ├── config_dpo.toml
│ ├── converted_dpo_dataset.yml
│ ├── ds_config.json
│ └── ultrafeedback.yml
├── kernels/
│ ├── cross_entropy_loss.py
│ └── utils.py
├── models/
│ ├── layers.py
│ ├── models.py
│ └── pipeline_model.py
├── pyproject.toml
├── requirements.txt
├── tools/
│ ├── convert_dpo_dataset_to_chat_format.py
│ ├── convert_ds_checkpoint_to_lora.py
│ ├── merge_lora.py
│ └── test_sampling.py
├── train.py
└── utils/
├── dataloader.py
├── dataset_utils.py
├── engine.py
├── hqq_utils.py
├── saver.py
├── unsloth_utils.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: .gitmodules
================================================
[submodule "axolotl"]
path = axolotl
url = https://github.com/OpenAccess-AI-Collective/axolotl/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 tdrussell
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# qlora-pipe
A pipeline parallel training script for LLMs.
Refer to the changelog at the bottom for details on updates.
## About
This is a training script I made so that I can fine-tune LLMs using my workstation with four 4090s. It is developed first and foremost for myself, with my own use cases in mind. It is scrappy and hacked together. It will likely *never* be a stable, well-supported training script like Axolotl. I am open sourcing the code in case it is useful to others, and also as a proof-of-concept that this kind of thing is possible.
That being said, if something doesn't work right, or you would like it to support some feature, feel free to raise an issue and I'll try to look at it.
## Features
- Pipeline parallel training, for efficiently training large models that cannot fit on one GPU
- Supports QLoRA, LoRA, and full fine tuning
- Quantize weights using either bitsandbytes or HQQ
- Efficient model loading. Each process only loads the layers it needs, and quantizes and moves them to the GPU layer-by-layer. This means you can load a large model on a lot of GPUs even with limited system RAM.
- Load any dataset that Axolotl can, using the same YAML config file format
- Support for "raw text" training using either a structured list of documents in a JSON file, or a single txt file
- Support for resuming training from a checkpoint, including the dataloader state, to easily allow training in a piecemeal fashion
- Useful metrics logged to Tensorboard
- Ability to specify a separate, fixed evaluation dataset
- Train on multiple datasets simultaneously, with different sampling ratios per dataset
- Models currently supported: Llama, Mistral, Mixtral, Qwen, Cohere (Command R), Phi-3 (mini and medium), Gemma 2, Gemma 3, Cohere2 (Command-A)
## Installing
Clone the repository:
```
git clone --recurse-submodules https://github.com/tdrussell/qlora-pipe
```
If you already cloned it and forgot to use --recurse-submodules:
```
git submodule init
git submodule update
```
Install Miniconda: https://docs.conda.io/en/latest/miniconda.html
Create the environment:
```
conda create -n qlora-pipe python=3.12
conda activate qlora-pipe
```
Install the dependencies:
```
pip install -r requirements.txt
```
Install nvcc:
```
conda install nvidia::cuda-nvcc
```
## Training
__Start by reading through the config files in the examples directory__. There are lots of comments explaining what the various fields do. Then, make a copy and edit it however you like. At minimum, change the paths at the top to point to your model and desired output directory. Launch the training script:
```
NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" deepspeed --num_gpus=1 train.py --deepspeed --config examples/config.toml
```
RTX 4000 series GPUs need those two environment variables set. Other GPUs may not need them.
## Parallelism
Deepspeed handles pipeline- and data-parallelism. Set the --num_gpus flag to however many GPUs you want to use. The config option `pipeline_stages` determines the level of model parallelism. Data parallelism is then set automatically so that all GPUs are used.
For example with 8 GPUs, and pipeline_stages=4, a single instance of the model is divided across 4 GPUs. Because there are 8 GPUs total, there are then 2 data-parallel instances.
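The arithmetic above can be sketched as a small helper. This is a hypothetical function, not part of the script; it assumes the usual DeepSpeed relation that the global batch (in examples) is micro batch per GPU × gradient accumulation steps × data-parallel replicas, which only applies when `batch_size_tokens` is unset.

```python
def parallel_layout(num_gpus, pipeline_stages, micro_batch_size_per_gpu=1, gradient_accumulation_steps=1):
    """Hypothetical helper: return (data-parallel replicas, examples per optimizer step)."""
    if num_gpus % pipeline_stages != 0:
        raise ValueError('pipeline_stages must evenly divide the number of GPUs')
    # Each model instance spans `pipeline_stages` GPUs; the rest of the GPUs become replicas.
    data_parallel = num_gpus // pipeline_stages
    # Assumed DeepSpeed-style global batch: micro batch * GAS * data-parallel replicas.
    examples_per_step = micro_batch_size_per_gpu * gradient_accumulation_steps * data_parallel
    return data_parallel, examples_per_step


# 8 GPUs split into 4 pipeline stages -> 2 data-parallel instances of the model.
print(parallel_layout(8, 4, micro_batch_size_per_gpu=1, gradient_accumulation_steps=4))  # (2, 8)
```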
The option `gradient_accumulation_steps` in the Deepspeed JSON config file determines the amount of pipelining when using pipeline parallelism (pipeline_stages>1). The higher the value, the more the GPUs can overlap computation. For example, with gradient_accumulation_steps=1, there is a single batch that gets passed between the GPUs forward, then in reverse for the backward pass. Only 1 GPU is active at a time, the others are idle. As gradient_accumulation_steps increases, you start pipelining multiple forward/backward batches. At the beginning and end of the step, some GPUs will always be idle. So as gradient_accumulation_steps approaches infinity, you approach 100% theoretical utilization. In practice, a value of 8 or so already gives good average utilization with 2 GPUs. With more GPUs, you may want to go higher.
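The utilization argument above can be made concrete with the standard pipeline-bubble estimate. This is an idealized sketch, assuming every stage takes the same time per micro-batch and ignoring communication; real step times will differ. Under these assumptions one step lasts `m + p - 1` micro-batch slots (forward plus backward), and each GPU is busy for `m` of them.

```python
def pipeline_utilization(gradient_accumulation_steps: int, pipeline_stages: int) -> float:
    """Idealized fraction of time each GPU is busy during one training step.

    Assumes equal per-stage time and no communication cost, so a step lasts
    (m + p - 1) micro-batch slots of which each GPU is busy for m.
    """
    m, p = gradient_accumulation_steps, pipeline_stages
    return m / (m + p - 1)


# With GAS=1 only one of the p GPUs works at a time (utilization 1/p);
# utilization approaches 1 as GAS grows.
for gas in (1, 4, 8, 32):
    print(gas, round(pipeline_utilization(gas, 2), 3), round(pipeline_utilization(gas, 4), 3))
```

The formula also suggests why more GPUs call for a higher value: the idle fraction `(p - 1) / (m + p - 1)` grows with the number of stages `p`.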
## Dataset configuration
There are 3 options for specifying each dataset. Set the `dataset_type` field to one of:
- axolotl
  - Loads the dataset using the Axolotl codebase. Set `dataset_path` to a YAML file that contains the same dataset configuration you would use in Axolotl.
- doclist
  - Set `dataset_path` to a glob pattern matching one or more JSON or JSONL files. Each file should be a list of objects containing a 'text' key. For each file, all of the text is logically concatenated together before being sliced into sequences.
- textfile
  - Basically the same as doclist, except the `dataset_path` matches one or more txt files. Each text file is sliced into sequences.
You can read dataset_utils.py for details on what each of these options is doing.
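As an illustration of the doclist format, the sketch below writes a JSON doclist file and then mimics "concatenate, then slice" with a toy whitespace tokenizer. The helper names are hypothetical, and it is not the logic in dataset_utils.py: whether separator tokens are inserted between documents, or a short final sequence is dropped, is not specified here, and this sketch simply keeps the remainder with no separators.

```python
import json
import tempfile
from pathlib import Path


def write_doclist(path, texts):
    # A doclist file: a JSON list of objects, each with a 'text' key.
    Path(path).write_text(json.dumps([{'text': t} for t in texts]))


def slice_doclist(path, sequence_len, tokenize):
    """Hypothetical sketch: concatenate all documents in one file, then cut into sequences.

    `tokenize` stands in for a real tokenizer. No separator tokens are inserted
    and a short final sequence is kept.
    """
    docs = json.loads(Path(path).read_text())
    stream = []
    for doc in docs:
        stream.extend(tokenize(doc['text']))
    return [stream[i:i + sequence_len] for i in range(0, len(stream), sequence_len)]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'docs.json'
    write_doclist(path, ['a b c', 'd e', 'f g h i'])
    # str.split acts as a toy whitespace "tokenizer"; sequences cross document boundaries.
    print(slice_doclist(path, sequence_len=4, tokenize=str.split))
    # [['a', 'b', 'c', 'd'], ['e', 'f', 'g', 'h'], ['i']]
```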
You can have multiple datasets. Just add additional `[[datasets]]` entries. When using multiple datasets, there are different ways to combine them.
- `dataset_combination_mode` = 'concatenate' (the default)
  - Just concatenates the datasets.
- `dataset_combination_mode` = 'interleave'
  - Uses the Huggingface Datasets library `interleave_datasets()` function.
  - Use the `dataset_interleave_stopping_strategy` setting to control when interleaving stops.
    - 'first_exhausted': stop when a dataset runs out of examples.
    - 'all_exhausted': stop when all datasets have run out of examples. This duplicates examples from smaller datasets.
  - When using the 'interleave' mode, datasets can have a relative `sample_weight`, which is a positive real number. This controls the relative proportion of the datasets when they are combined.
  - __IMPORTANT__: When using the 'interleave' mode, the manner in which the datasets are proportionally combined (i.e. sampled from) is affected by the `batch_size_tokens` setting:
    - If `batch_size_tokens` is unset, it means you are treating each example equally. Every batch has the same number of examples, even though they may be different lengths. So, when interleaving datasets, the rows are sampled according to the relative proportions given by the `sample_weight`.
    - If using `batch_size_tokens`, it means you are treating each token equally. Every batch varies the number of examples (because they might have different lengths) so that the token count is approximately constant. So, when interleaving datasets, the sampling ratios are adjusted so that the number of *tokens*, not rows, drawn from different datasets matches the `sample_weight`. This is implemented by scaling the sampling probabilities by the average length of the dataset. You can read the `combine_datasets()` function in dataset_utils.py if this is confusing.
  - __Which of these should I use?__ Probably set `batch_size_tokens`. I think this is the better way to think about things, and it matches what sample packing would do. For example, in Axolotl, it is recommended to use sample packing, which packs multiple examples into a single sequence so that the sequence length is constant. This means, in the loss function, each token is being treated with equal weight, not each original row in the dataset. Using `batch_size_tokens` in this training script mimics that behavior, and thus when interleaving datasets, it samples from them so that the token ratios adhere to the sample_weight specified.
  - __Example__: you have datasets A and B. B's average row length is twice that of A. A has a sample_weight of 2, B has a sample_weight of 1.
    - Not setting batch_size_tokens: when interleaving, you get 2 rows of A for every row of B.
    - Using batch_size_tokens: when interleaving, you get 4 rows of A for every row of B. This is because A's rows are on average half the length of B's rows, so you need twice as many as before so that the number of tokens in each matches the 2:1 ratio you specified with the sample_weight.
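The weighting rule in the example above can be sketched as follows. This is a hypothetical helper, not the script's `combine_datasets()`; it assumes "scaling by the average length" means dividing each weight by that dataset's average row length before normalizing. The resulting list is the kind of value the Hugging Face `interleave_datasets()` function accepts through its `probabilities` argument.

```python
def interleave_probabilities(sample_weights, avg_lengths=None):
    """Hypothetical sketch: turn per-dataset sample_weight values into row-sampling probabilities.

    Without avg_lengths (batch_size_tokens unset), rows are sampled in proportion to
    the weights. With avg_lengths (batch_size_tokens set), each weight is divided by
    the dataset's average row length, so *tokens* end up in the weight ratio.
    """
    if avg_lengths is None:
        raw = [float(w) for w in sample_weights]
    else:
        if len(avg_lengths) != len(sample_weights):
            raise ValueError('need one average length per dataset')
        raw = [w / length for w, length in zip(sample_weights, avg_lengths)]
    total = sum(raw)
    return [r / total for r in raw]


# Datasets A and B from the example: weights 2:1, B's rows twice as long as A's.
print(interleave_probabilities([2, 1]))              # rows drawn 2:1
print(interleave_probabilities([2, 1], [100, 200]))  # rows drawn 4:1, tokens 2:1
```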
## On sample packing (or the lack thereof)
Sample packing is not currently implemented. Instead, there is the option `batch_size_tokens`. If this field is set, the per-device batch size is ignored, and instead the batch size is adjusted dynamically to target a fixed number of tokens per batch, per device. This was easier to implement than sample packing, and does basically the same thing. It is also efficient: if I set batch_size_tokens to a modest 10000 and train a 7B model with the Alpaca dataset, all my 4090s hit their 350W power limit cap. Unless I'm missing something (definitely possible), it seems there is no need to support sample packing.
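A minimal sketch of what "target a fixed number of tokens per batch" can look like. This is a hypothetical helper, not the script's batching code: it assumes a batch costs `len(batch) * longest_example` tokens (everything padded to the longest example) and sorts by length first, as `group_by_length` does. The real implementation may count tokens differently or add constraints such as matching batch counts across data-parallel ranks.

```python
def batch_by_tokens(lengths, batch_size_tokens):
    """Hypothetical sketch: group example indices into batches of ~batch_size_tokens padded tokens.

    Assumes a batch costs len(batch) * longest_example tokens. An example longer
    than the budget still ends up alone in its own batch.
    """
    # Sort by length first (like group_by_length) so each batch holds similar lengths.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, current_max = [], [], 0
    for i in order:
        new_max = max(current_max, lengths[i])
        # Close the current batch if adding this example would exceed the token budget.
        if current and new_max * (len(current) + 1) > batch_size_tokens:
            batches.append(current)
            current, new_max = [], lengths[i]
        current.append(i)
        current_max = new_max
    if current:
        batches.append(current)
    return batches


# Short examples share a batch; long ones get fewer examples per batch.
print(batch_by_tokens([100, 200, 300, 400, 1000], batch_size_tokens=1000))  # [[0, 1, 2], [3], [4]]
```

In the sketch, the batches come out in length order; as the example config notes, the script still shuffles batches after grouping them.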
## Floating point precision
There are different places you can specify the floating point dtype. `model_weight_dtype` controls the precision of the underlying model weights (for any weights not quantized), and `lora_weight_dtype` is for the lora weights. If you are using quantization, both bnb and hqq have options for the compute dtype as well.
If you are using 16 bit dtypes, floating point roundoff error is a potential problem. For a good overview of the problem and solutions, see [Revisiting Bfloat16 Training](https://arxiv.org/pdf/2010.06192). TL;DR: the main source of precision error when training with 16 bit weights is the weight update step: $(p = p + \Delta p * lr)$. When the update is very small compared to the parameter (which is often the case), there can be significant roundoff error, including the update being dropped entirely. Mixed precision training solves this by keeping a master copy of the weights in fp32 and running all optimizer steps in fp32. Kahan summation is another solution when training in full bf16: it keeps an extra bf16 buffer for each parameter that accumulates the roundoff errors, so that updates are never dropped.
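To see the effect on a single scalar parameter, the sketch below emulates bf16 in pure Python by rounding float32 bit patterns with `struct`. It illustrates the idea only and is not this repo's `adamw_kahan` optimizer. The update size `1e-4` and the 1000 steps are arbitrary choices: near 1.0 the bf16 spacing is 2^-7 ≈ 0.0078, so a 1e-4 update is always rounded away unless the lost part is carried forward.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    bfloat16 is the top 16 bits of a float32, so the float32 bit pattern is rounded
    at bit 16. NaN/Inf/overflow handling is omitted for brevity.
    """
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the 16 dropped bits
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]


def naive_bf16(p: float, update: float, steps: int) -> float:
    # p = p + update, with the result stored in bf16 every step.
    for _ in range(steps):
        p = to_bf16(p + update)
    return p


def kahan_bf16(p: float, update: float, steps: int) -> tuple[float, float]:
    comp = 0.0  # bf16 buffer holding the part of past updates not yet absorbed by p
    for _ in range(steps):
        y = to_bf16(update + comp)          # this update plus previously lost roundoff
        t = to_bf16(p + y)                  # new parameter value, rounded to bf16
        comp = to_bf16(y - to_bf16(t - p))  # what rounding discarded this step
        p = t
    return p, comp


naive = naive_bf16(1.0, 1e-4, 1000)
kahan, comp = kahan_bf16(1.0, 1e-4, 1000)
print(naive)        # 1.0: every update was rounded away
print(kahan, comp)  # kahan lands near the exact result of 1.1
```

The compensation buffer is itself bf16, so the result is not exact, but the parameter moves by roughly the intended 0.1 instead of staying frozen.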
### Okay but how should I configure things?
- If unsure, set everything to bf16 and use the adamw_kahan optimizer type. Kahan summation is ESPECIALLY important for full fine tuning. Kahan summation requires an extra 2 bytes per trainable parameter compared to vanilla full bf16 training.
- For LoRAs, another option is setting `lora_weight_dtype` to fp32, which also makes all optimizer states fp32.
- For LoRAs only, with constant learning rate no lower than 5e-5 or so, I have seen full bf16 training with no Kahan summation mostly match fp32 or bf16 + Kahan.
- (more experimental) You may try Deepspeed's bf16 mode, but I personally don't use this. I think this does something like mixed precision, where it wraps the optimizer to keep a master copy of the parameters in fp32, as well as doing gradient accumulation and all optimizer states in fp32. This will use much more memory than full bf16 + Kahan summation.
## Changelog
### 2025-03-12
- Change how weights are loaded to avoid Transformers internal method
- Support Gemma 3
### 2025-01-30
- Add pretokenized dataset option.
- Update layers to work with the new way to pass position embeddings in HF Transformers. Please update Transformers to the latest version or you will get errors.
### 2024-12-29
- Add DPO training. The examples directory has a DPO example.
### 2024-07-02
- Add Gemma-2 support.
### 2024-06-20
- Add adamw_kahan optimizer type and make it the default in the example.
### 2024-05-19
**The old config file format will break.** Quantization is configured slightly differently now. Read examples/config_7b.toml. It's only a few lines to change.
- Change how quantization is configured. Quantization is now its own table in the TOML file.
- Add HQQ quantization.
### 2024-04-28
- Add llama3 instruction formatting option when loading a ShareGPT formatted dataset using Axolotl.
- Automatically add BOS token for Llama 3.
- Add option for Unsloth activation checkpointing, which saves VRAM for a very small hit to performance.
### 2024-04-16
- Optimizer is now specified in the config.toml file.
- Can use AdamW8Bit optimizer.
- MLP offloading works again. For MoE, can offload a specified number of experts.
- Can have separate dtype for saved files.
- Cohere model support (command-r)
### 2024-04-07
Make sure to update requirements! Axolotl does some dynamic importing, so things will break in a very hard to diagnose way if you don't have a new dependency that was added.
- Removed the need for manually specifying cache directories for datasets. All dataset processing uses the Huggingface Datasets library and takes advantage of the automatic caching that it provides.
- Added the ability to specify multiple datasets, with different ways to combine them. __This breaks the old config format for datasets.__ Refer to the example config for what it should look like now.
================================================
FILE: examples/alpaca.yml
================================================
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
================================================
FILE: examples/capybara.yml
================================================
chat_template: llama3
datasets:
- path: ssmi153/Capybara-ShareGPT
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
================================================
FILE: examples/config.toml
================================================
# Paths
model = '/data2/models/Meta-Llama-3.1-8B'
output_dir = '/data/training_runs/llama3_8b_example'
# Lora configuration
# can use full_fine_tune=true and no quantization to train the whole model instead of a LoRA
#full_fine_tune = true
lora_rank = 64
lora_alpha = 64
lora_dropout = 0.05
# Train only specific modules. This is passed to the parameter of the same name in the LoraConfig.
# If not set, adapt all linear modules.
# Note, this ALSO affects full fine tuning. In that case, if this is set, only weights containing one
# of these keys as substring will have requires_grad. If not set everything is trained.
#target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
# can specify layers to adapt with LoRA if you want
#layers_to_transform = '16:31'
# Training settings
epochs = 2
lr_scheduler = 'cosine' # can also be 'constant'
warmup_steps = 100
# Batch size of a single forward/backward pass for one GPU.
micro_batch_size_per_gpu = 1
# Dynamic batch size, targeting this many tokens per batch, per device.
# If set, completely ignores micro_batch_size_per_gpu.
# Can be thought of as a replacement for sample packing.
batch_size_tokens = 5000
# Number of pipeline parallel stages, must evenly divide the number of GPUs you launch the script with. The lower this is, the more GPU VRAM you need, but the faster the training will be.
# A value of 1 means the model will be loaded on every GPU in full, each GPU running its own training batches.
# A value of 2 means the model will be split in half, with the halves loaded evenly across the (2, 4, 6, 8, ...) GPUs, where GPUs work in pairs on each batch. (And so on.)
pipeline_stages = 1
# Number of micro-batches sent through the pipeline for each training step.
# If pipeline_stages > 1, a higher GAS means better GPU utilization due to smaller pipeline bubbles (where GPUs aren't overlapping computation).
gradient_accumulation_steps = 4
# Grad norm clipping.
gradient_clipping = 1.0
# might be useful if resuming from a checkpoint and you want to change the LR and force it to something
#force_constant_lr = 5e-5
# hard clamp the magnitude of the LoRA weights
#scale_weight_norms = 1.0
# for Mixtral, set the load balancing coefficient
#load_balancing_loss_coef = 0.02
# Eval settings
eval_steps = 100 # how often to run eval
eval_before_first_step = true # do an eval before any training happens
eval_after_last_step = false # do a final eval after the training completes
# Performance settings
logging_steps = 10 # how often to log in Tensorboard
save_steps = 200 # how often to save the model
checkpoint_every_n_minutes = 60 # how frequently to checkpoint training states (used for resuming training)
# checkpoint_on_save = true # alternative to the above, this will cause a checkpoint save every time a regular save occurs (note: this setting takes precedence over checkpoint_every_n_minutes)
# dtype to load the underlying model weights in
model_weight_dtype = 'bfloat16'
# dtype for the LoRA weights
lora_weight_dtype = 'bfloat16'
# Can have the saved weights be different dtype. Don't need to set this. Could be useful for
# training in float32 but saving with float16.
#save_dtype = 'bfloat16'
# Keep this number of stepXXXX (model saves) and global_stepXXX (checkpoint saves) and delete the rest
# (this only applies to the current training session, and resumed training sessions will not touch
# old saves)
keep_states = 3
# sort examples by length before dividing them into batches
# this makes all examples in a batch approximately the same length, to minimize padding
# the batches are still shuffled after that
# you should probably always have this set to true
group_by_length = true
# This can also be 'unsloth' to offload hidden states to CPU, saving potentially a lot of VRAM
# for a minor performance hit.
# Example: 4x4090, PCIE 3.0 16x, pipeline_stages=4, training QLoRA on Llama 3 70B with 4096 sequence length.
# true: 75s step time, 19.7G peak per-GPU VRAM usage.
# 'unsloth': 78s step time, 16.2G peak per-GPU VRAM usage.
activation_checkpointing = true
# Keep MLP weights on system RAM until they are needed. Can save a ton of VRAM with a
# moderate hit to performance. If using an MoE model, this can also be an integer, in
# which case only that many experts are offloaded (tradeoff between VRAM and speed).
#offload_mlp_to_cpu = true
# Resume a prior run
# if true, we attempt to resume training from the most recent directory inside output_dir (the directory names are timestamps)
# so, to resume, just run the exact same command but set this to true first
resume_from_checkpoint = false
# Loading the optimizer states seems to cause some kind of unavoidable VRAM memory leak.
# It's very small, only about 0.2 GB in cases I've seen. But if you are very close to the
# limit, it can cause resuming from checkpoint to OOM. As a last resort, you can uncomment
# this to not load the optimizer states and hopefully the resumption won't OOM.
#load_optimizer_states = false
# Dataset configuration
# How to combine multiple datasets if you have more than one.
# Can be 'concatenate' or 'interleave'. Will be 'concatenate' if not set.
dataset_combination_mode = 'interleave'
# When to stop interleaving datasets when using mode 'interleave'. Either 'first_exhausted' or 'all_exhausted'.
# Default if not set: 'first_exhausted'
dataset_interleave_stopping_strategy = 'all_exhausted'
# Can set this lower than training, so we don't drop as many examples when trying to make equal-sized batches.
# Default if not set: same as training GAS.
eval_gradient_accumulation_steps = 1
# bitsandbytes 4 bit quantization. The parameters here become arguments to Transformers BitsAndBytesConfig.
[quantization.bnb]
load_in_4bit = true
bnb_4bit_use_double_quant = false
bnb_4bit_compute_dtype = 'bfloat16'
# HQQ quantization. The parameters here become arguments to CustomHQQConfig.
# [quantization.hqq]
# nbits = 4
# group_size = 64
# compute_dtype = 'bfloat16'
# (Optional) You can override the quant params for certain modules. This does substring matching, e.g. if 'gate_proj'
# is a substring of the full module name, anything specified overwrites the defaults in [quantization.hqq].
# [quantization.hqq.dynamic_config]
# gate_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true}
# up_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true}
# down_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true}
[optimizer]
# options: adamw_kahan, AdamW, AdamW8bit
type = 'adamw_kahan'
lr = 5e-5
beta1 = 0.9
beta2 = 0.99
weight_decay = 0.1
[[datasets]]
# Arbitrary name, used only for separately logging eval metrics. Will be dataset0, dataset1, etc if not set.
name = 'alpaca'
dataset_type = 'axolotl'
dataset_path = 'examples/alpaca.yml'
sequence_len = 2048
eval_size = 0.02
# Relative sampling weight, when using combination mode 'interleave'. Will be 1 if not set.
sample_weight = 1
[[datasets]]
name = 'capybara'
dataset_type = 'axolotl'
dataset_path = 'examples/capybara.yml'
sequence_len = 2048
eval_size = 0.02
sample_weight = 1.5
# In addition to using eval_size which splits off some of the dataset, we can have completely separate datasets for eval.
# This can be useful if you're training on raw text data, so that the eval set remains completely fixed, even if
# you change training sequence_len, etc.
# This is just an example, typically you wouldn't have this overlap a training dataset.
# [[eval_datasets]]
# name = 'capybara'
# dataset_type = 'axolotl'
# dataset_path = 'examples/capybara.yml'
# sequence_len = 2048
================================================
FILE: examples/config_dpo.toml
================================================
# Paths
model = '/data2/models/Meta-Llama-3.1-8B-Instruct'
output_dir = '/data/training_runs/llama3_8b_dpo_example'
lora_rank = 64
lora_alpha = 64
lora_dropout = 0.05
epochs = 2
lr_scheduler = 'constant'
warmup_steps = 100
batch_size_tokens = 5000
micro_batch_size_per_gpu = 1
pipeline_stages = 1
gradient_accumulation_steps = 4
gradient_clipping = 1.0
eval_steps = 100
eval_before_first_step = true
eval_after_last_step = false
logging_steps = 10
save_steps = 200
checkpoint_every_n_minutes = 60
model_weight_dtype = 'bfloat16'
lora_weight_dtype = 'bfloat16'
group_by_length = true
activation_checkpointing = true
eval_gradient_accumulation_steps = 1
[rl]
method = 'dpo'
dpo_beta = 0.02
# [quantization.bnb]
# load_in_4bit = true
# bnb_4bit_use_double_quant = false
# bnb_4bit_compute_dtype = 'bfloat16'
[optimizer]
type = 'adamw_kahan'
lr = 5e-5
beta1 = 0.9
beta2 = 0.99
weight_decay = 0.1
[[datasets]]
name = 'ultrafeedback'
dataset_type = 'axolotl'
dataset_path = 'examples/ultrafeedback.yml'
sequence_len = 4096
eval_size = 0.01
================================================
FILE: examples/converted_dpo_dataset.yml
================================================
# Some DPO datasets are not in conversation format.
# They need to be in conversation format to load them in this script. You can convert them:
# python tools/convert_dpo_dataset_to_chat_format.py unalignment/toxic-dpo-v0.2 ~/data/toxic-dpo-v0.2-converted
chat_template: llama3
datasets:
- path: json
data_files:
- /home/anon/data/toxic-dpo-v0.2-converted/train.json
split: train
type: orpo.chat_template
================================================
FILE: examples/ds_config.json
================================================
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"gradient_clipping": 1.0,
"steps_per_print": 1
}
================================================
FILE: examples/ultrafeedback.yml
================================================
# This dataset is already in a format that can be directly loaded by the orpo.chat_template type.
chat_template: llama3
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: orpo.chat_template
================================================
FILE: kernels/cross_entropy_loss.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import triton.language as tl

from .utils import MAX_FUSED_SIZE, calculate_settings, device_warp_size


@triton.heuristics(
    {
        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],
    }
)
@triton.jit
def _cross_entropy_forward(
    logits_ptr,
    logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DO_LOGIT_SCALING: tl.constexpr,
    LOGIT_SCALE: tl.constexpr,
):
    """
    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
    Pi = exp(xi) / sum(exp(xi))
    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
         = -y [ x - log[sum(exp(x))] ]
         = y * (log[sum(exp(x))] - x)
    If y == 0: CE_i = 0
    If y == 1: CE_i = logsumexp - x

    logsumexp is also stable
    Take y = log[sum(exp(x))]
    exp(y) = sum(exp(x))
    exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
    exp(y) = exp(c)*sum(exp(x - c))
    y = log(exp(c)*sum(exp(x - c)))
    y = c + log[sum(exp(x - c))]
    This means we can set c = max(x) to make sure
    exp(x - c) always is exp(x - max(x)).
    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
    """
    row_idx = tl.program_id(0)
    logits_ptr += row_idx * logits_row_stride.to(tl.int64)
    loss_ptr += row_idx
    logsumexp_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)
    if DO_LOGIT_SCALING:
        # Logit scaling: s * x
        logits = LOGIT_SCALE * logits
    pass
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if label_idx != -100:
        x = tl.load(logits_ptr + label_idx).to(tl.float32)
        if DO_LOGIT_SCALING:
            # Logit scaling: s * x
            x = LOGIT_SCALE * x
        pass
        loss = logsumexp - x
    else:
        loss = 0.0
    tl.store(logsumexp_ptr, logsumexp)
    tl.store(loss_ptr, loss)


pass
@triton.heuristics(
{
'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],
}
)
@triton.jit
def _chunked_cross_entropy_forward(
logits_ptr,
logits_row_stride,
loss_ptr,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE: tl.constexpr,
N_CHUNKS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE: tl.constexpr,
):
"""
256K vocab divided in 4 chunks
|-65536-| |-65536-| |-65536-| |-65536-|
|-------| |-------| |-------| |-------|
|-------| |-------| |-------| |-------|
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
Notice we can compute a logsumexp for each chunk (a, ..., z) and then
combine the chunk results with one more logsumexp:
logsumexp[logsumexp(a), ..., logsumexp(z)]
= log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
= log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
= log[sum(exp(a)) + ... + sum(exp(z))]
= logsumexp(x)
So each chunk performs its own logsumexp, followed by a final logsumexp
reduction over the chunks, i.e. loss = logsumexp(chunked_logsumexp) - x.
"""
row_idx = tl.program_id(0)
chunk_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)
if DO_LOGIT_SCALING:
# Logit scaling: s * x
logits = LOGIT_SCALE * logits
pass
c = tl.max(logits, 0)
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
if chunk_idx == 0:
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
if label_idx != -100:
x = tl.load(logits_ptr + label_idx).to(tl.float32)
if DO_LOGIT_SCALING:
# Logit scaling: s * x
x = LOGIT_SCALE * x
pass
loss = -1.0 * x
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass
tl.store(logsumexp_ptr, logsumexp)
pass
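# Reference sketch (hypothetical helper): the chunked forward path computes one
# logsumexp per vocabulary chunk, and the host then reduces those values with a
# second logsumexp (torch.logsumexp(logsumexp, dim=1) in Fast_CrossEntropyLoss).
# This pure-Python version assumes `chunk_size` plays the role of MAX_FUSED_SIZE;
# the last chunk is simply shorter instead of being padded with -inf.
import math


def _sketch_chunked_logsumexp(xs, chunk_size):
    def lse(vals):
        c = max(vals)
        return c + math.log(sum(math.exp(v - c) for v in vals))

    # Per-chunk logsumexp, like one program_id(1) of _chunked_cross_entropy_forward.
    chunk_lses = [lse(xs[i:i + chunk_size]) for i in range(0, len(xs), chunk_size)]
    # A logsumexp over the chunk results equals the logsumexp over the whole row.
    return lse(chunk_lses)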
@triton.heuristics(
{
'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],
}
)
@triton.jit
def _cross_entropy_backward(
logits_ptr,
logits_row_stride,
dloss_ptr,
dloss_row_stride,
logsumexp_ptr,
labels_ptr,
VOCAB_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
DO_LOGIT_SCALING: tl.constexpr,
LOGIT_SCALE: tl.constexpr,
):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
From https://en.wikipedia.org/wiki/LogSumExp
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
If y == 0: dC/dx = 0
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
row_idx = tl.program_id(0)
block_idx = tl.program_id(1)
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
dloss_ptr += row_idx * dloss_row_stride
col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:
dloss = tl.load(dloss_ptr)
else:
dloss = 0.0
x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)
if DO_LOGIT_SCALING:
# d/dx [s * x] = s
x = LOGIT_SCALE * x
pass
logsumexp = tl.load(logsumexp_ptr + row_idx)
y = tl.exp(x - logsumexp)
y = tl.where(
col_offsets == label_idx,
y - 1.0, # exp(x - logsumexp) - 1
y, # exp(x - logsumexp)
)
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
if DO_LOGIT_SCALING:
# d/dx [s * x] = s
y = LOGIT_SCALE * y
pass
tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)
pass
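# Reference sketch (hypothetical helper): _cross_entropy_backward overwrites the
# logits with dloss * s * (softmax(s * x) - onehot(label)), or with zeros when
# label == -100. This pure-Python version handles a single row and can be
# checked against finite differences of the forward loss.
import math


def _sketch_cross_entropy_grad(logits, label, dloss=1.0, logit_scale=1.0):
    if label == -100:
        return [0.0] * len(logits)
    xs = [logit_scale * x for x in logits]
    c = max(xs)
    lse = c + math.log(sum(math.exp(x - c) for x in xs))
    grad = []
    for i, x in enumerate(xs):
        y = math.exp(x - lse)  # softmax probability of class i
        if i == label:
            y -= 1.0  # exp(x - logsumexp) - 1 at the label position
        grad.append(dloss * logit_scale * y)  # chain rule: d(s * x)/dx = s
    return grad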
class Fast_CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, logit_scale=1.0):
n_rows, vocab_size = logits.shape
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
n_chunks = div + (mod != 0)
losses = torch.empty(n_rows, dtype=torch.float32, device='cuda')
if n_chunks == 1:
# For small vocabs <= 65536 (MAX_FUSED_SIZE) like Llama, Mistral
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
logsumexp = torch.empty(n_rows, dtype=torch.float32, device='cuda')
_cross_entropy_forward[(n_rows,)](
logits,
logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
DO_LOGIT_SCALING=(logit_scale != 1.0),
LOGIT_SCALE=logit_scale,
num_warps=num_warps,
)
else:
# For large vocabs > 65536 (MAX_FUSED_SIZE) like Gemma 256K
logsumexp = torch.empty(
(
n_rows,
n_chunks,
),
dtype=torch.float32,
device='cuda',
)
_chunked_cross_entropy_forward[
(
n_rows,
n_chunks,
)
](
logits,
logits.stride(0),
losses,
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
N_CHUNKS=n_chunks,
BLOCK_SIZE=MAX_FUSED_SIZE,
DO_LOGIT_SCALING=(logit_scale != 1.0),
LOGIT_SCALE=logit_scale,
num_warps=32 if device_warp_size() < 64 else 16,
)
# logsumexp(chunked_logsumexp) - x
# Do the -x separately
logsumexp = torch.logsumexp(logsumexp, dim=1) # Row sum
losses += logsumexp
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
pass
ctx.save_for_backward(logits, logsumexp, labels)
ctx.logit_scale = logit_scale
return losses
pass
@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
n_rows, vocab_size = logits.shape
BLOCK_SIZE = 4096
div, mod = divmod(vocab_size, BLOCK_SIZE)
n_blocks = div + (mod != 0)
_cross_entropy_backward[
(
n_rows,
n_blocks,
)
](
logits,
logits.stride(0),
dlosses,
dlosses.stride(0),
logsumexp,
labels,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
DO_LOGIT_SCALING=(ctx.logit_scale != 1.0),
LOGIT_SCALE=ctx.logit_scale,
num_warps=8,
)
return (
logits,
None,
None,
)
pass
pass
def fast_cross_entropy_loss(logits, labels, logit_scale=1.0):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
labels: (batch, seq_len,)
Returns:
loss: scalar tensor, the sum of per-token losses divided by the number of labels != -100
"""
batch, seq_len, d = logits.shape
assert labels.shape == (batch, seq_len)
loss = Fast_CrossEntropyLoss.apply(
logits.view(batch * seq_len, d),
labels.view(-1),
logit_scale,
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
pass
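# Reference sketch (hypothetical helper): fast_cross_entropy_loss sums the
# per-token losses (already 0 where label == -100) and divides by the number of
# non-ignored labels, i.e. a mean over real tokens only. Plain lists stand in
# for the flattened (batch * seq_len,) tensors. If every label is ignored, the
# original divides by a zero count; this sketch raises ZeroDivisionError instead.
def _sketch_masked_mean(per_token_losses, labels, ignore_index=-100):
    n_items = sum(1 for lab in labels if lab != ignore_index)
    total = sum(loss for loss, lab in zip(per_token_losses, labels) if lab != ignore_index)
    return total / n_items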
================================================
FILE: kernels/utils.py
================================================
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import bitsandbytes as bnb
import torch
import triton
MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2
def device_warp_size():
if torch.cuda.is_available() and torch.version.hip and torch.cuda.get_device_capability(0)[0] < 10:
return 64
else:
return 32
def calculate_settings(n):
BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(
f'Cannot launch Triton kernel since n = {n} exceeds the maximum CUDA blocksize = {MAX_FUSED_SIZE}.'
)
warp_scalar = 32 / float(device_warp_size())
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32 * warp_scalar
elif BLOCK_SIZE >= 8192:
num_warps = 16 * warp_scalar
elif BLOCK_SIZE >= 2048:
num_warps = 8 * warp_scalar
return BLOCK_SIZE, int(num_warps)
pass
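# Reference sketch (hypothetical helper): a Triton-free mirror of
# calculate_settings for checking launch parameters without a GPU. The warp size
# is passed in instead of being queried through device_warp_size(). Vocabularies
# above MAX_FUSED_SIZE never reach this function because the forward pass
# switches to the chunked kernel for them.
def _sketch_calculate_settings(n, warp_size=32, max_fused_size=65536):
    block_size = 1
    while block_size < n:  # next power of two >= n, like triton.next_power_of_2
        block_size *= 2
    if block_size > max_fused_size:
        raise RuntimeError(f'n = {n} exceeds the maximum block size {max_fused_size}')
    warp_scalar = 32 / float(warp_size)
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32 * warp_scalar
    elif block_size >= 8192:
        num_warps = 16 * warp_scalar
    elif block_size >= 2048:
        num_warps = 8 * warp_scalar
    return block_size, int(num_warps)


# Usage: a Llama-2-style 32000-token vocab rounds up to one 32768-wide block,
# so _sketch_calculate_settings(32000) returns (32768, 32).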
get_ptr = bnb.functional.get_ptr
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
def QUANT_STATE(W):
return getattr(W, 'quant_state', None)
pass
def get_lora_parameters(proj):
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, 'base_layer') else proj
W = base_layer.weight
if not hasattr(proj, 'disable_adapters') or proj.disable_adapters or proj.merged:
return W, QUANT_STATE(W), None, None, None
pass
active_adapter = proj.active_adapters[0] if hasattr(proj, 'active_adapters') else proj.active_adapter
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
return W, QUANT_STATE(W), A, B, s
pass
def fast_dequantize(W, quant_state=None, out=None):
if quant_state is None:
return W
if type(quant_state) is not list:
# New quant_state as a class
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
# Old quant_state as a list of lists
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# Create weight matrix
if out is None:
out = torch.empty(shape, dtype=dtype, device='cuda')
else:
assert out.shape == shape
assert out.dtype == dtype
# NF4 dequantization of statistics
n_elements_absmax = absmax.numel()
out_absmax = torch.empty(n_elements_absmax, dtype=torch.float32, device='cuda')
# Do dequantization
ptr_out_absmax = get_ptr(out_absmax)
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
)
out_absmax += offset
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else cdequantize_blockwise_bf16_nf4
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), ctypes.c_int(blocksize), ctypes.c_int(out.numel()))
# Careful returning transposed data
is_transposed = True if W.shape[0] == 1 else False
return out.t() if is_transposed else out
pass
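# Reference sketch (hypothetical helpers, simplified): fast_dequantize undoes
# bitsandbytes' blockwise 4-bit quantization. Each block of `blocksize` weights
# stores 4-bit indices into a 16-entry codebook plus one absmax scale, and
# weight = codebook[index] * absmax[block]. The real NF4 codebook and the
# second-level ("double") quantization of absmax (code2/absmax2/offset above)
# are omitted. The evenly spaced codebook below is an assumption for illustration.
_SKETCH_CODEBOOK = [-1.0 + 2.0 * i / 15 for i in range(16)]


def _sketch_quantize_blockwise(weights, blocksize):
    indices, absmax = [], []
    for start in range(0, len(weights), blocksize):
        block = weights[start:start + blocksize]
        scale = max(abs(w) for w in block) or 1.0  # avoid dividing by zero
        absmax.append(scale)
        for w in block:
            # Nearest codebook entry to the normalised weight in [-1, 1].
            idx = min(range(16), key=lambda i: abs(_SKETCH_CODEBOOK[i] - w / scale))
            indices.append(idx)
    return indices, absmax


def _sketch_dequantize_blockwise(indices, absmax, blocksize):
    # weight = codebook[index] * absmax of the block the position belongs to
    return [_SKETCH_CODEBOOK[idx] * absmax[pos // blocksize] for pos, idx in enumerate(indices)]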
def fast_gemv(X, W, quant_state, out=None):
if quant_state is None:
return torch.matmul(X, W, out=out)
# For fast X @ W where seq_len == 1
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
_, q_len, hd = X.shape
# assert(q_len == 1)
if type(quant_state) is not list:
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
absmax = quant_state.absmax
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
stats = quant_state.code
offset = quant_state.offset
state2 = quant_state.state2
absmax2 = state2.absmax
code2 = state2.code
blocksize2 = state2.blocksize
else:
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
offset, state2 = compressed_stats
absmax2, code2, blocksize2, _, _, _, _ = state2
pass
# assert(dtype == X.dtype)
bout = shape[0]
if out is None:
out = torch.empty(
(
1,
1,
bout,
),
dtype=dtype,
device='cuda',
)
# else:
# assert(out.shape == (1, 1, bout,))
# pass
n = 1
m = shape[0]
k = shape[1]
lda = shape[0]
ldc = shape[0]
ldb = (hd + 1) // 2
m = ctypes.c_int32(m)
n = ctypes.c_int32(n)
k = ctypes.c_int32(k)
lda = ctypes.c_int32(lda)
ldb = ctypes.c_int32(ldb)
ldc = ctypes.c_int32(ldc)
df = torch.empty(absmax.shape, dtype=torch.float32, device='cuda')
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
get_ptr(df),
ctypes.c_int(blocksize2),
ctypes.c_int(df.numel()),
)
df += offset
absmax = df
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else cgemm_4bit_inference_naive_bf16
blocksize = ctypes.c_int32(blocksize)
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize)
return out
pass
def fast_linear_forward(proj, X, temp_lora=None, out=None):
W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
bsz, q_len, in_dim = X.shape
if q_len != 1:
return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
if W_quant is None:
out = torch.matmul(X, W.t(), out=out)
elif bsz == 1 and q_len == 1:
out = fast_gemv(X, W, W_quant, out=out)
else:
W = fast_dequantize(W.t(), W_quant)
out = torch.matmul(X, W, out=out)
pass
# Add in LoRA weights
if lora_A is not None:
out_dim = out.shape[2]
dtype = X.dtype
if not hasattr(lora_A, '_fast_lora'):
lora_A._fast_lora = lora_A.to(dtype)
lora_B._fast_lora = lora_B.to(dtype)
pass
if bsz == 1:
out = out.view(out_dim)
temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out=temp_lora)
out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S)
else:
out = out.view(bsz, out_dim)
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out=temp_lora)
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S)
pass
out = out.view(bsz, 1, out_dim)
pass
return out
pass
def matmul_lora(X, W, W_quant, A, B, s, out=None):
dtype = X.dtype
W = fast_dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, d = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
pass
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
# LoRA is enabled
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass
return out.view(batch, seq_len, -1) if reshape else out
pass
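# Reference sketch (hypothetical helpers): matmul_lora computes
# X @ W^T + (X @ A^T) @ (s * B^T), which equals X @ (W + s * B @ A)^T. Applying
# the low-rank path separately avoids materialising the merged weight. Nested
# lists stand in for tensors; W is (out, in), A is (r, in), B is (out, r).
def _sketch_matmul(P, Q):
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)] for row in P]


def _sketch_transpose(M):
    return [list(col) for col in zip(*M)]


def _sketch_lora_forward(X, W, A, B, s):
    base = _sketch_matmul(X, _sketch_transpose(W))
    low_rank = _sketch_matmul(_sketch_matmul(X, _sketch_transpose(A)), _sketch_transpose(B))
    return [[b + s * lr for b, lr in zip(rb, rl)] for rb, rl in zip(base, low_rank)]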
================================================
FILE: models/layers.py
================================================
import math
import torch
import torch.nn.functional as F
import transformers
from deepspeed.runtime.pipe import module as ds_pipe_module
from torch import nn
from kernels.cross_entropy_loss import Fast_CrossEntropyLoss
def move_data_to_device(module, device):
non_blocking = (device != 'cpu')
# handle lora
if hasattr(module, 'base_layer'):
module = module.base_layer
# handle HQQ
if hasattr(module, 'W_q'):
orig_data = module.W_q.data
module.W_q.data = orig_data.to(device, non_blocking=non_blocking)
else:
orig_data = module.weight.data
module.weight.data = orig_data.to(device, non_blocking=non_blocking)
return orig_data
def set_data(module, data):
# handle lora
if hasattr(module, 'base_layer'):
module = module.base_layer
# handle HQQ
if hasattr(module, 'W_q'):
module.W_q.data = data
else:
module.weight.data = data
def move_experts_to_device(experts, device, num_experts_to_offload):
orig_data = []
for i in range(num_experts_to_offload):
orig_w1 = move_data_to_device(experts[i].w1, device)
orig_w2 = move_data_to_device(experts[i].w2, device)
orig_w3 = move_data_to_device(experts[i].w3, device)
orig_data.append((orig_w1, orig_w2, orig_w3))
return orig_data
def set_experts_data(experts, orig_data):
for i, (orig_w1, orig_w2, orig_w3) in enumerate(orig_data):
set_data(experts[i].w1, orig_w1)
set_data(experts[i].w2, orig_w2)
set_data(experts[i].w3, orig_w3)
def entropy_fn(logits):
result = []
# There is a very wide range of chunk sizes that cause no increase in memory reported by
# nvidia-smi (Torch re-using blocks of memory?). Computing it as one tensor uses a huge
# amount of memory. A chunk size of 128 seems good enough for now.
for logits_chuck in torch.split(logits, 128):
result.append(torch.distributions.Categorical(logits=logits_chuck).entropy())
return torch.cat(result).float()
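# Reference sketch (hypothetical helper): entropy_fn relies on
# torch.distributions.Categorical(logits=...).entropy(), which for one row is
# H = -sum(p_i * log p_i) with p = softmax(x). Using log p_i = x_i - logsumexp(x)
# gives the stable form below. OutputLayer divides H by log(vocab_size), the
# entropy of the uniform distribution, to obtain a value in [0, 1].
import math


def _sketch_entropy_from_logits(xs):
    c = max(xs)
    lse = c + math.log(sum(math.exp(x - c) for x in xs))
    # p_i = exp(x_i - lse); H = -sum(p_i * (x_i - lse))
    return -sum(math.exp(x - lse) * (x - lse) for x in xs)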
def top_k_accuracy(logits, labels, k_list, ignore_index=-100):
keep = labels != ignore_index
labels = labels[keep].view(-1, 1)
max_k = max(k_list)
_, top_k_predictions = torch.topk(logits, max_k, dim=-1, sorted=True)
top_k_predictions = top_k_predictions[keep]
accuracies = []
for k in k_list:
accuracies.append(torch.any(top_k_predictions[:, :k] == labels, dim=-1).to(torch.float32).mean())
return accuracies
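# Reference sketch (hypothetical helper): top_k_accuracy drops positions whose
# label equals ignore_index, then counts a hit when the label is among the k
# highest-scoring classes. Ties are broken here by lower index first, which may
# differ from torch.topk's tie-breaking.
def _sketch_top_k_accuracy(logits_rows, labels, k_list, ignore_index=-100):
    kept = [(row, lab) for row, lab in zip(logits_rows, labels) if lab != ignore_index]
    accuracies = []
    for k in k_list:
        hits = 0
        for row, lab in kept:
            top_k = sorted(range(len(row)), key=lambda i: (-row[i], i))[:k]
            hits += lab in top_k
        accuracies.append(hits / len(kept))
    return accuracies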
class LayerSpec(ds_pipe_module.LayerSpec):
def __init__(self, typename, *module_args, **module_kwargs):
super().__init__(typename, *module_args, **module_kwargs)
def build(self):
self.module_kwargs.pop('_estimated_size', None)
return self.typename(*self.module_args, **self.module_kwargs)
@property
def estimated_size(self):
return self.module_kwargs.get('_estimated_size', 1)
# TODO: consider using Liger-Kernel fused loss implementations. The inputs are already set up to support this.
# Would save VRAM, but some metrics could no longer be computed (e.g. entropy, accuracies).
class OutputLayer(nn.Module):
def __init__(
self,
pipeline_model,
loader_util,
lm_head,
logit_scale=1.0,
loss_type='cross_entropy_loss',
focal_loss_gamma=0,
tie_weights=None,
logit_softcapping=None,
):
super().__init__()
# Assign list to prevent registering the nn.Module
self.pipeline_model = [pipeline_model]
# Unlike the other wrapper classes, this is called lm_head and not orig. Because this is directly a
# nn.Linear layer, it needs to keep the same attribute name so quantization knows not to quantize it.
self.lm_head = lm_head
self.logit_scale = logit_scale
self.loss_type = loss_type.lower()
self.focal_loss_gamma = focal_loss_gamma
if tie_weights:
self.lm_head.weight.original_name = tie_weights
self.logit_softcapping = logit_softcapping
loader_util.load_state_dict_into_module(self)
if self.loss_type == 'cross_entropy_loss' and self.focal_loss_gamma != 0:
raise ValueError("focal_loss_gamma can't be used with 'cross_entropy_loss' function")
def forward(self, inputs):
hidden_states, labels = inputs
if self.pipeline_model[0].sampling_mode:
# When sampling only compute the last logits.
hidden_states = hidden_states[:, -1:, :]
labels = labels.to(hidden_states.device)
if self.logit_scale == 1.0:
logits = self.lm_head(hidden_states)
else:
logits = self.lm_head(self.logit_scale * hidden_states)
if self.logit_softcapping is not None and self.logit_softcapping > 0:
logits = logits / self.logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.logit_softcapping
if self.pipeline_model[0].sampling_mode:
return logits
extra_ignored_labels = torch.full((labels.shape[0], 1), -100, device=logits.device)
labels = torch.hstack((labels[..., 1:], extra_ignored_labels))
# Flatten the tokens
vocab_size = logits.size(-1)
flat_logits = logits.view(-1, vocab_size)
flat_labels = labels.view(-1)
flat_loss_mask = flat_labels >= 0
cross_entropy_loss = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels)
loss = None
if self.loss_type == 'cross_entropy_loss':
cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
loss_unreduced = cross_entropy_loss
elif self.loss_type == 'focal_loss':
cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
# See https://arxiv.org/abs/1708.02002 (Section 3)
p = torch.exp(-cross_entropy_loss)
loss_unreduced = (1 - p) ** self.focal_loss_gamma * cross_entropy_loss
elif self.loss_type == 'focal_loss_star':
cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
# See https://arxiv.org/abs/1708.02002 (Appendix A/B)
# NOTE: The use of Beta makes no sense for the multinomial case as it's invariant to translation
loss_unreduced = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels, self.focal_loss_gamma)
loss_unreduced = loss_unreduced[flat_loss_mask]
loss_unreduced = loss_unreduced / self.focal_loss_gamma
elif self.loss_type == 'inverse_focal_loss':
cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
# See "Rethinking Calibration of Deep Neural Networks: Do Not Be Afraid of Overconfidence" (Section 5.2)
# NOTE: The alternative of p^gamma (instead of (1+p)^gamma) might be useful for gradient ascent...
p = torch.exp(-cross_entropy_loss)
loss_unreduced = (1 + p) ** self.focal_loss_gamma * cross_entropy_loss
elif self.loss_type == 'exponentiated_cross_entropy_loss':
cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
# See "Gradient as a Foundation for Building a Loss Function" (Section III.B)
# NOTE: This is a generalisation of their "Quadratic Cross-Entropy" loss (QCE: gamma=2, CE: gamma=1, etc).
loss_unreduced = cross_entropy_loss**self.focal_loss_gamma / self.focal_loss_gamma
elif self.loss_type == 'dpo':
rl_config = self.pipeline_model[0].train_config['rl']
cross_entropy_loss = cross_entropy_loss.view_as(labels) # unflatten
loss_mask = labels >= 0
logps = -(cross_entropy_loss * loss_mask).sum(-1)
half = cross_entropy_loss.size(0) // 2
chosen_logps = logps[:half]
rejected_logps = logps[half:]
if self.pipeline_model[0].dpo_reference_mode:
self.reference_chosen_logps = chosen_logps.detach()
self.reference_rejected_logps = rejected_logps.detach()
return torch.tensor(0.0, device=logits.device)
# log the language modeling loss metrics on the chosen completion
cross_entropy_loss = cross_entropy_loss[:half].flatten()[loss_mask[:half].flatten()]
hidden_states = hidden_states[:half]
loss_unreduced = cross_entropy_loss
flat_logits = logits[:half].view(-1, vocab_size)
flat_labels = labels[:half].view(-1)
flat_loss_mask = flat_labels >= 0
policy_chosen_logps = chosen_logps
policy_rejected_logps = rejected_logps
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = self.reference_chosen_logps - self.reference_rejected_logps
del self.reference_chosen_logps
del self.reference_rejected_logps
dpo_logits = pi_logratios - ref_logratios
loss = -F.logsigmoid(rl_config['dpo_beta'] * dpo_logits).mean()
else:
raise NotImplementedError(self.loss_type)
with torch.no_grad():
log_vocab_size = math.log(logits.size(-1))
entropy = entropy_fn(flat_logits)[flat_loss_mask]
# Compute normalised entropy so we can compare between models with different vocab sizes
normalised_entropy = entropy / log_vocab_size
# Compute the (negative) log-likelihood using the original *UNADJUSTED* Cross-Entropy loss.
log_likelihood = cross_entropy_loss.mean()
# Compute McFadden's Pseudo-R² metric using log(vocab_size) as the null log-likelihood.
mcfaddens_pseudo_r2 = 1 - (log_likelihood / log_vocab_size)
accuracies = top_k_accuracy(flat_logits, flat_labels, k_list=[1, 5, 20])
# Compute the norms of the (pre-logit-scaled) hidden states
hidden_state_norms = torch.norm(hidden_states.float(), dim=-1)
hidden_state_norms = hidden_state_norms.view(-1)[flat_loss_mask]
if loss is None:
# Normal language modeling loss types (e.g. not DPO)
loss = loss_unreduced.mean()
loss_unreduced = loss_unreduced.detach()
return (
loss,
loss_unreduced,
hidden_state_norms,
entropy,
normalised_entropy,
log_likelihood,
mcfaddens_pseudo_r2,
*accuracies,
)
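# Reference sketches (hypothetical helpers): the label shift performed in
# OutputLayer.forward, the per-token transforms it applies to the unreduced
# cross-entropy `ce` (with p = exp(-ce), the probability of the correct token),
# and the DPO objective. Each works on Python floats for one token or one pair.
import math


def _sketch_shift_labels(labels, ignore_index=-100):
    # Position t is trained to predict token t + 1; the last position has no target.
    return labels[1:] + [ignore_index]


def _sketch_focal_loss(ce, gamma):
    return (1 - math.exp(-ce)) ** gamma * ce  # down-weights confident tokens


def _sketch_inverse_focal_loss(ce, gamma):
    return (1 + math.exp(-ce)) ** gamma * ce  # up-weights confident tokens


def _sketch_exponentiated_ce(ce, gamma):
    return ce**gamma / gamma  # gamma = 1 recovers plain cross-entropy


def _sketch_dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta):
    # Arguments are summed log-probabilities of each completion.
    logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    # -log(sigmoid(beta * logits)) == log(1 + exp(-beta * logits));
    # not overflow-safe for very negative beta * logits.
    return math.log1p(math.exp(-beta * logits))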
def load_balancing_loss_func(gate_logits: tuple, num_experts: int = None, top_k: int = 2) -> torch.Tensor:
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
stacked_gate_logits = torch.stack([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
routing_weights = torch.nn.functional.softmax(stacked_gate_logits, dim=-1) # [num_layers, num_tokens, num_experts]
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1) # [num_layers, num_tokens, top_k]
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_experts
) # [num_layers, num_tokens, top_k, num_experts]
# For a given token, determine if it was routed to a given expert. Think of this as a collection of top_k-hot vectors.
expert_mask = torch.max(expert_mask, dim=-2).values.float() # [num_layers, num_tokens, num_experts]
tokens_per_layer_and_expert = torch.mean(expert_mask, dim=-2) # [num_layers, num_experts]
router_prob_per_layer_and_expert = torch.mean(routing_weights, dim=-2) # [num_layers, num_experts]
return torch.mean(tokens_per_layer_and_expert * router_prob_per_layer_and_expert) * num_experts**2
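# Reference sketch (hypothetical helper): load_balancing_loss_func above, for a
# single layer. f_e is the fraction of tokens whose top_k experts include e, and
# P_e is the mean router probability for e. The result is mean_e(f_e * P_e) * E^2,
# i.e. E * sum_e(f_e * P_e), so perfectly balanced routing gives top_k.
import math


def _sketch_load_balancing_loss(router_logits, num_experts, top_k):
    n_tokens = len(router_logits)
    f = [0.0] * num_experts
    P = [0.0] * num_experts
    for row in router_logits:
        c = max(row)
        exps = [math.exp(x - c) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]  # softmax over experts
        chosen = sorted(range(num_experts), key=lambda i: -probs[i])[:top_k]
        for e in range(num_experts):
            P[e] += probs[e] / n_tokens
            f[e] += (e in chosen) / n_tokens
    return sum(fe * pe for fe, pe in zip(f, P)) / num_experts * num_experts**2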
class MixtralOutputLayer(OutputLayer):
def __init__(
self,
pipeline_model,
loader_util,
lm_head,
load_balancing_loss_coef,
num_experts,
num_experts_per_tok,
**kwargs,
):
super().__init__(pipeline_model, loader_util, lm_head, **kwargs)
self.load_balancing_loss_coef = load_balancing_loss_coef
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
def forward(self, inputs):
hidden_states, labels, *router_logits = inputs
router_logits = tuple(router_logits)
outputs = super().forward((hidden_states, labels))
if self.pipeline_model[0].sampling_mode:
return outputs
if self.load_balancing_loss_coef is not None:
aux_loss = transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func(
router_logits, self.num_experts, self.num_experts_per_tok
)
alternate_aux_loss = load_balancing_loss_func(router_logits, self.num_experts, self.num_experts_per_tok)
loss = outputs[0]
loss += self.load_balancing_loss_coef * aux_loss
outputs = (loss, *outputs[1:], aux_loss, alternate_aux_loss)
return outputs
class InputLayer(nn.Module):
def __init__(self, model):
super().__init__()
self._model = [model]
self.embed_tokens = model.model.embed_tokens
self.rotary_emb = model.model.rotary_emb
self.embedding_on_cpu = not self.model.train_config['full_fine_tune']
self.model.loader_util.load_state_dict_into_module(self)
@property
def model(self):
return self._model[0]
def forward(self, inputs):
past_key_values = None
cache_position = None
use_cache = self.model.sampling_mode
input_ids, attention_mask, labels = inputs[:3]
device = input_ids.device
if self.embedding_on_cpu:
self.embed_tokens.to('cpu')
input_ids = input_ids.to('cpu')
inputs_embeds = self.embed_tokens(input_ids).to(device)
if use_cache:
past_key_values = self.model.cache
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
position_ids = cache_position.unsqueeze(0)
attention_mask = self.model.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, None
)
if attention_mask is None:
# attention_mask can end up being None, which means use full causal attention. But with pipeline parallelism,
# we can only pass tensors between layers. So make it an empty tensor, which will later be detected by the layers
# and converted back to None. Note: this only works now because dynamic_shape=True in the pipeline engine.
attention_mask = torch.tensor([], device=device)
# Work around a very strange Deepspeed bug. The combination of PipelineModule dynamic_shape=True, attention_mask being
# an integer dtype, and pipeline_stages>2 causes training (but not eval) to hang. So cast to float, and cast back to int
# in the layer.
if torch.is_tensor(attention_mask) and not torch.is_floating_point(attention_mask):
attention_mask = attention_mask.to(inputs_embeds.dtype)
hidden_states = inputs_embeds
if self.model.model.config.model_type == 'gemma2':
normalizer = torch.tensor(self.model.model.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
cos, sin = self.rotary_emb(hidden_states, position_ids)
output = hidden_states, attention_mask, cos, sin, cache_position, labels
# Deepspeed requirement. Float tensors must require grad.
for tensor in output:
if torch.is_floating_point(tensor):
tensor.requires_grad_(True)
return output
class LlamaRMSNormPipe(nn.Module):
def __init__(self, loader_util, orig):
super().__init__()
self.orig = orig
loader_util.load_state_dict_into_module(self)
def forward(self, inputs):
hidden_states, _, _, _, _, labels, *router_logits = inputs
return self.orig(hidden_states), labels, *router_logits
class LlamaDecoderLayerPipe(nn.Module):
def __init__(self, pipeline_model, loader_util, orig):
super().__init__()
self.pipeline_model = [pipeline_model]
self.orig = orig
self.mlp_offloaded_to_cpu = False
self.attn_implementation = pipeline_model.config._attn_implementation
loader_util.load_state_dict_into_module(self)
# A note on MLP offloading:
# We take advantage of how activation checkpointing works with reentrant checkpointing functions.
# During the forward pass, if gradients are disabled (eval or first forward pass of activation checkpointing)
# we offload the weights back to CPU at the end of the function. If gradients are enabled (second forward pass
# of activation checkpointing) we leave the weights on GPU, and use a backward hook to offload to CPU after the
# backward pass of this function is completed. This way the weights stay on the GPU for the backward pass.
def forward(self, inputs):
def move_mlp_to_cpu_hook(grad):
self.move_mlp_to_cpu()
return None
hidden_states, attention_mask, cos, sin, cache_position, labels = inputs
if self.mlp_offloaded_to_cpu:
if hidden_states.requires_grad:
hidden_states.register_hook(move_mlp_to_cpu_hook)
self.move_mlp_to_device(hidden_states.device)
kwargs = {}
if attention_mask.numel() == 0:
# We can't pass None between pipeline layers, so this signals that attention_mask should be None.
kwargs['attention_mask'] = None
else:
# We have to pass attention mask between layers as float dtype, because in certain cases training hangs otherwise. So
# now cast it back to int64 if we are using flash_attention_2.
kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask
kwargs['position_embeddings'] = (cos, sin)
if self.pipeline_model[0].sampling_mode:
kwargs['use_cache'] = True
kwargs['past_key_value'] = self.pipeline_model[0].cache
kwargs['cache_position'] = cache_position
result = (self.orig(hidden_states, **kwargs)[0], attention_mask, cos, sin, cache_position, labels)
if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():
self.move_mlp_to_cpu()
return result
def move_mlp_to_cpu(self):
# If it's already been moved to CPU once, just set the data to avoid a transfer.
if self.mlp_offloaded_to_cpu:
set_data(self.orig.mlp.up_proj, self.cpu_up_proj)
set_data(self.orig.mlp.down_proj, self.cpu_down_proj)
set_data(self.orig.mlp.gate_proj, self.cpu_gate_proj)
return
move_data_to_device(self.orig.mlp.up_proj, 'cpu')
move_data_to_device(self.orig.mlp.down_proj, 'cpu')
move_data_to_device(self.orig.mlp.gate_proj, 'cpu')
self.mlp_offloaded_to_cpu = True
def move_mlp_to_device(self, device):
self.cpu_up_proj = move_data_to_device(self.orig.mlp.up_proj, device)
self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device)
self.cpu_gate_proj = move_data_to_device(self.orig.mlp.gate_proj, device)
class Phi3DecoderLayerPipe(LlamaDecoderLayerPipe):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def move_mlp_to_cpu(self):
if self.mlp_offloaded_to_cpu:
set_data(self.orig.mlp.gate_up_proj, self.cpu_gate_up_proj)
set_data(self.orig.mlp.down_proj, self.cpu_down_proj)
return
move_data_to_device(self.orig.mlp.gate_up_proj, 'cpu')
move_data_to_device(self.orig.mlp.down_proj, 'cpu')
self.mlp_offloaded_to_cpu = True
def move_mlp_to_device(self, device):
self.cpu_gate_up_proj = move_data_to_device(self.orig.mlp.gate_up_proj, device)
self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device)
class MixtralDecoderLayerPipe(LlamaDecoderLayerPipe):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_experts_to_offload = self.pipeline_model[0].num_experts_to_offload
def forward(self, inputs):
def move_mlp_to_cpu_hook(grad):
self.move_mlp_to_cpu()
return None
hidden_states, attention_mask, cos, sin, cache_position, labels, *input_router_logits = inputs
if self.mlp_offloaded_to_cpu:
if hidden_states.requires_grad:
hidden_states.register_hook(move_mlp_to_cpu_hook)
self.move_mlp_to_device(hidden_states.device)
kwargs = {}
if attention_mask.numel() == 0:
# We can't pass None between pipeline layers, so this signals that attention_mask should be None.
kwargs['attention_mask'] = None
else:
kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask
kwargs['position_embeddings'] = (cos, sin)
if self.pipeline_model[0].sampling_mode:
kwargs['use_cache'] = True
kwargs['past_key_value'] = self.pipeline_model[0].cache
hidden_states, router_logits = self.orig(hidden_states, output_router_logits=True, **kwargs)
# TODO: fix unsloth gradient checkpointing when we return router logits
# router_logits = router_logits.to(torch.float32)
# router_logits = input_router_logits + (router_logits,)
# result = (hidden_states, attention_mask, cos, sin, labels, *router_logits)
result = (hidden_states, attention_mask, cos, sin, cache_position, labels)
if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():
self.move_mlp_to_cpu()
return result
def move_mlp_to_cpu(self):
if self.mlp_offloaded_to_cpu:
set_experts_data(self.orig.block_sparse_moe.experts, self.orig_data)
return
move_experts_to_device(self.orig.block_sparse_moe.experts, 'cpu', self.num_experts_to_offload)
self.mlp_offloaded_to_cpu = True
def move_mlp_to_device(self, device):
self.orig_data = move_experts_to_device(
self.orig.block_sparse_moe.experts, device, self.num_experts_to_offload
)
class Gemma3InputLayer(nn.Module):
def __init__(self, model):
super().__init__()
self._model = [model]
self.embed_tokens = model.model.embed_tokens
self.rotary_emb = model.model.rotary_emb
self.rotary_emb_local = model.model.rotary_emb_local
self.embedding_on_cpu = not self.model.train_config['full_fine_tune']
self.model.loader_util.load_state_dict_into_module(self)
@property
def model(self):
return self._model[0]
def forward(self, inputs):
past_key_values = None
cache_position = None
use_cache = self.model.sampling_mode
input_ids, attention_mask, labels = inputs[:3]
device = input_ids.device
if self.embedding_on_cpu:
self.embed_tokens.to('cpu')
input_ids = input_ids.to('cpu')
inputs_embeds = self.embed_tokens(input_ids).to(device)
if use_cache:
past_key_values = self.model.cache
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
position_ids = cache_position.unsqueeze(0)
attention_mask = self.model.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, None
)
if attention_mask is None:
# attention_mask can end up being None, which means use full causal attention. But with pipeline parallelism,
# we can only pass tensors between layers. So make it an empty tensor, which will later be detected by the layers
# and converted back to None. Note: this only works now because dynamic_shape=True in the pipeline engine.
attention_mask = torch.tensor([], device=device)
# Work around a very strange Deepspeed bug. The combination of PipelineModule dynamic_shape=True, attention_mask being
# an integer dtype, and pipeline_stages>2 causes training (but not eval) to hang. So cast to float, and cast back to int
# in the layer.
if torch.is_tensor(attention_mask) and not torch.is_floating_point(attention_mask):
attention_mask = attention_mask.to(inputs_embeds.dtype)
hidden_states = inputs_embeds
if self.model.model.config.model_type == 'gemma2':
normalizer = torch.tensor(self.model.model.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
position_embeddings_global_cos, position_embeddings_global_sin = self.rotary_emb(hidden_states, position_ids)
position_embeddings_local_cos, position_embeddings_local_sin = self.rotary_emb_local(hidden_states, position_ids)
output = hidden_states, attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels
# Deepspeed requirement. Float tensors must require grad.
for tensor in output:
if torch.is_floating_point(tensor):
tensor.requires_grad_(True)
return output
class Gemma3DecoderLayerPipe(nn.Module):
def __init__(self, pipeline_model, loader_util, orig):
super().__init__()
self.pipeline_model = [pipeline_model]
self.orig = orig
self.mlp_offloaded_to_cpu = False
self.attn_implementation = pipeline_model.config._attn_implementation
loader_util.load_state_dict_into_module(self)
def forward(self, inputs):
def move_mlp_to_cpu_hook(grad):
self.move_mlp_to_cpu()
return None
hidden_states, attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels = inputs
if self.mlp_offloaded_to_cpu:
if hidden_states.requires_grad:
hidden_states.register_hook(move_mlp_to_cpu_hook)
self.move_mlp_to_device(hidden_states.device)
kwargs = {}
if attention_mask.numel() == 0:
# We can't pass None between pipeline layers, so this signals that attention_mask should be None.
kwargs['attention_mask'] = None
else:
kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask
kwargs['position_embeddings_global'] = (position_embeddings_global_cos, position_embeddings_global_sin)
kwargs['position_embeddings_local'] = (position_embeddings_local_cos, position_embeddings_local_sin)
kwargs['cache_position'] = cache_position
if self.pipeline_model[0].sampling_mode:
kwargs['use_cache'] = True
kwargs['past_key_value'] = self.pipeline_model[0].cache
result = (
self.orig(hidden_states, **kwargs)[0],
attention_mask,
position_embeddings_global_cos,
position_embeddings_global_sin,
position_embeddings_local_cos,
position_embeddings_local_sin,
cache_position,
labels
)
if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():
self.move_mlp_to_cpu()
return result
def move_mlp_to_cpu(self):
# If it's already been moved to CPU once, just set the data to avoid a transfer.
if self.mlp_offloaded_to_cpu:
set_data(self.orig.mlp.up_proj, self.cpu_up_proj)
set_data(self.orig.mlp.down_proj, self.cpu_down_proj)
set_data(self.orig.mlp.gate_proj, self.cpu_gate_proj)
return
move_data_to_device(self.orig.mlp.up_proj, 'cpu')
move_data_to_device(self.orig.mlp.down_proj, 'cpu')
move_data_to_device(self.orig.mlp.gate_proj, 'cpu')
self.mlp_offloaded_to_cpu = True
def move_mlp_to_device(self, device):
self.cpu_up_proj = move_data_to_device(self.orig.mlp.up_proj, device)
self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device)
self.cpu_gate_proj = move_data_to_device(self.orig.mlp.gate_proj, device)
class Gemma3RMSNormPipe(nn.Module):
def __init__(self, loader_util, orig):
super().__init__()
self.orig = orig
loader_util.load_state_dict_into_module(self)
def forward(self, inputs):
hidden_states, *_, labels = inputs
return self.orig(hidden_states), labels
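`Gemma3InputLayer` and the decoder-layer wrappers above encode a `None` attention mask as an empty tensor, because pipeline stages can only exchange tensors. They also send integer masks as floats to avoid the DeepSpeed hang described in the comments, and cast back to `int64` for `flash_attention_2`. The following is a minimal sketch of that round trip, assuming only PyTorch. `encode_mask` and `decode_mask` are hypothetical helpers and are not part of this repo.

```python
import torch


def encode_mask(attention_mask, float_dtype=torch.bfloat16, device='cpu'):
    # Hypothetical helper. Pipeline stages may only exchange tensors, so None becomes an empty tensor.
    if attention_mask is None:
        return torch.tensor([], device=device)
    # Integer masks are sent as floats (mirrors the DeepSpeed hang workaround in Gemma3InputLayer).
    if not torch.is_floating_point(attention_mask):
        attention_mask = attention_mask.to(float_dtype)
    return attention_mask


def decode_mask(attention_mask, attn_implementation):
    # Hypothetical helper. An empty tensor means "no mask", i.e. full causal attention.
    if attention_mask.numel() == 0:
        return None
    # flash_attention_2 expects an integer padding mask, so cast back.
    if attn_implementation == 'flash_attention_2':
        return attention_mask.to(torch.int64)
    return attention_mask


mask = torch.tensor([[1, 1, 0]])
sent = encode_mask(mask)
print(sent.dtype, decode_mask(sent, 'flash_attention_2').tolist())
```

Using a zero-element tensor as the sentinel keeps every value a tensor. As the original comment notes, this relies on `dynamic_shape=True` in the pipeline engine.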
================================================
FILE: models/models.py
================================================
import accelerate
import torch
import transformers
from models.layers import (
InputLayer,
LayerSpec,
LlamaDecoderLayerPipe,
LlamaRMSNormPipe,
MixtralDecoderLayerPipe,
MixtralOutputLayer,
OutputLayer,
Phi3DecoderLayerPipe,
Gemma3InputLayer,
Gemma3DecoderLayerPipe,
Gemma3RMSNormPipe,
)
from models.pipeline_model import PipelineModel
from utils.utils import DTYPE_MAP
DEFAULT_ATTN_IMPLEMENTATION = 'flash_attention_2'
# A little bit of inheritance and MRO trickery since LlamaForCausalLM.__init__ only takes a
# positional argument. We inherit PipelineModel first, but call LlamaForCausalLM init first,
# and make sure PipelineModel doesn't have a super().__init__() call.
class LlamaForCausalLMPipe(PipelineModel, transformers.LlamaForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.LlamaConfig.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.LlamaForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
embedding_relative_size = 4
embedding_on_cpu = not self.train_config['full_fine_tune']
result = [LayerSpec(InputLayer, self, _estimated_size=0 if embedding_on_cpu else embedding_relative_size)]
for block in self.model.layers:
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
_estimated_size=embedding_relative_size,
)
)
return result
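The comment above `LlamaForCausalLMPipe` describes a specific pattern. The class inherits `PipelineModel` first, but initializes the Hugging Face base class first, and `PipelineModel.__init__` deliberately does not call `super().__init__()`. Below is a stdlib-only sketch of that pattern. `Base`, `Mixin` and `Combined` are stand-in names, not classes from this repo.

```python
class Base:
    # Stands in for the Hugging Face model class: builds the model state.
    def __init__(self, model_config):
        self.model_config = model_config
        self.init_order = ['Base']

    def describe(self):
        return 'base'


class Mixin:
    # Stands in for PipelineModel: deliberately no super().__init__() call,
    # so Base.__init__ (which has a different signature) is not re-entered.
    def __init__(self, train_config):
        self.train_config = train_config
        self.init_order.append('Mixin')

    def describe(self):
        return 'mixin'


class Combined(Mixin, Base):
    def __init__(self, train_config, model_config):
        # Explicit, ordered calls instead of cooperative super(): Base first, then Mixin.
        Base.__init__(self, model_config)
        Mixin.__init__(self, train_config)


c = Combined({'lr': 1e-4}, {'hidden_size': 16})
print(c.init_order)                            # ['Base', 'Mixin']
print([k.__name__ for k in Combined.__mro__])  # ['Combined', 'Mixin', 'Base', 'object']
print(c.describe())                            # mixin (the first base wins method lookup)
```

If `Mixin.__init__` called `super().__init__(train_config)`, the MRO would route that call to `Base.__init__` with the wrong argument and reset its state. This is why the original avoids a `super()` call there.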
class Qwen2ForCausalLMPipe(PipelineModel, transformers.Qwen2ForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.Qwen2Config.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.Qwen2ForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
result = [LayerSpec(InputLayer, self)]
for block in self.model.layers:
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
)
)
return result
class CohereForCausalLMPipe(PipelineModel, transformers.CohereForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.CohereConfig.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.CohereForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
# the embedding table for this model is huge; load balance it better with some heuristics
embedding_relative_size = 4
embedding_on_cpu = not self.train_config['full_fine_tune']
result = [LayerSpec(InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)]
for block in self.model.layers:
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
logit_scale=self.logit_scale,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
_estimated_size=embedding_relative_size,
)
)
return result
class Phi3ForCausalLMPipe(PipelineModel, transformers.Phi3ForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.Phi3Config.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.Phi3ForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
result = [LayerSpec(InputLayer, self)]
for block in self.model.layers:
result.append(LayerSpec(Phi3DecoderLayerPipe, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
)
)
return result
class Gemma2ForCausalLMPipe(PipelineModel, transformers.Gemma2ForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.Gemma2Config.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.Gemma2ForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
# the embedding table for this model is huge; load balance it better with some heuristics
# this value optimized for LoRA, pipeline_stages=2
embedding_relative_size = 8
embedding_on_cpu = not self.train_config['full_fine_tune']
result = [LayerSpec(InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)]
for block in self.model.layers:
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
logit_softcapping=self.config.final_logit_softcapping,
_estimated_size=embedding_relative_size,
)
)
return result
class MistralForCausalLMPipe(PipelineModel, transformers.MistralForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.MistralConfig.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.MistralForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
result = [LayerSpec(InputLayer, self)]
for block in self.model.layers:
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
)
)
return result
class MixtralForCausalLMPipe(PipelineModel, transformers.MixtralForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.MixtralConfig.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.MixtralForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
self.load_balancing_loss_coef = config.get('load_balancing_loss_coef', None)
self.num_experts_to_offload = self.num_experts
if 'offload_mlp_to_cpu' in config and isinstance(config['offload_mlp_to_cpu'], int):
self.num_experts_to_offload = config['offload_mlp_to_cpu']
def to_layer_specs(self):
result = [LayerSpec(InputLayer, self)]
for block in self.model.layers:
result.append(LayerSpec(MixtralDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
MixtralOutputLayer,
self,
self.loader_util,
self.lm_head,
load_balancing_loss_coef=self.load_balancing_loss_coef,
num_experts=self.num_experts,
num_experts_per_tok=self.num_experts_per_tok,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
)
)
return result
class Gemma3ForCausalLMPipe(PipelineModel, transformers.Gemma3ForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.Gemma3TextConfig.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.Gemma3ForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
# the embedding table for this model is huge; load balance it better with some heuristics
# this value optimized for LoRA, pipeline_stages=2
embedding_relative_size = 8
embedding_on_cpu = not self.train_config['full_fine_tune']
result = [LayerSpec(Gemma3InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)]
for block in self.model.layers:
result.append(LayerSpec(Gemma3DecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(Gemma3RMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
logit_softcapping=self.config.final_logit_softcapping,
_estimated_size=embedding_relative_size,
)
)
return result
class Cohere2ForCausalLMPipe(PipelineModel, transformers.Cohere2ForCausalLM):
def __init__(self, config, quantization_config):
model_config = transformers.Cohere2Config.from_pretrained(config['model'])
model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)
torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])
with accelerate.init_empty_weights():
transformers.Cohere2ForCausalLM.__init__(self, model_config)
PipelineModel.__init__(self, config, quantization_config, model_config)
torch.set_default_dtype(torch.float32)
def to_layer_specs(self):
# the embedding table for this model is huge; load balance it better with some heuristics
embedding_relative_size = 8
embedding_on_cpu = not self.train_config['full_fine_tune']
result = [LayerSpec(InputLayer, self, _estimated_size=2 if embedding_on_cpu else embedding_relative_size)]
for block in self.model.layers:
result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))
result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))
result.append(
LayerSpec(
OutputLayer,
self,
self.loader_util,
self.lm_head,
logit_scale=self.logit_scale,
loss_type=self.loss_type,
focal_loss_gamma=self.focal_loss_gamma,
tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,
_estimated_size=embedding_relative_size,
)
)
return result
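The `_estimated_size` values above are load-balancing hints. A large embedding or output layer counts as several decoder blocks, and a layer whose embedding stays on the CPU counts as little or nothing. How the repo actually consumes these hints is not shown here. The following is a hypothetical sketch of one way such weights could drive a contiguous split across pipeline stages. It assumes a default weight of 1 per decoder block. `partition_by_estimated_size` is an illustrative name, not the repo's partitioner.

```python
from itertools import accumulate


def partition_by_estimated_size(sizes, num_stages):
    """Split layer weights into `num_stages` contiguous runs; returns num_stages + 1 boundaries."""
    if not 1 <= num_stages <= len(sizes):
        raise ValueError('need between 1 and len(sizes) stages')
    prefix = [0] + list(accumulate(sizes))  # prefix[i] = total weight of layers [0, i)
    total = prefix[-1]
    bounds = [0]
    for stage in range(1, num_stages):
        target = total * stage / num_stages  # ideal cumulative weight at this boundary
        lo = bounds[-1] + 1  # the current stage gets at least one layer
        hi = len(sizes) - (num_stages - stage)  # leave at least one layer for each later stage
        bounds.append(min(range(lo, hi + 1), key=lambda i: abs(prefix[i] - target)))
    bounds.append(len(sizes))
    return bounds


# Input layer (embedding on CPU, size 1), four decoder blocks (assumed size 1), norm (0), output (8).
print(partition_by_estimated_size([1, 1, 1, 1, 1, 0, 8], 2))  # [0, 5, 7]
```

In this example, the heavy output layer pulls the final stage's boundary forward, so the last stage holds only the norm and output layers. This is the effect the heuristics in `to_layer_specs` aim for.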
================================================
FILE: models/pipeline_model.py
================================================
import os
from collections import defaultdict
from inspect import signature
import re
import accelerate
import bitsandbytes as bnb
import transformers
from deepspeed.accelerator import get_accelerator
from hqq.core import quantize as hqq_quantize
from torch import nn
from transformers.integrations import get_keys_to_not_convert
from accelerate.utils import set_module_tensor_to_device
import utils.hqq_utils as hqq_utils
from utils.utils import is_main_process
LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX = r'^language_model\.'
class PipelineModel(nn.Module):
def __init__(self, config, quantization_config, model_config):
if config['full_fine_tune'] and model_config.tie_word_embeddings:
raise NotImplementedError('FFT is not supported for models with tied embeddings')
self.train_config = config
self.model_config = model_config
self.modules_to_not_quantize = get_keys_to_not_convert(self)
self.loader_util = LoaderUtil(config['model'], quantization_config, self.modules_to_not_quantize)
self.loss_type = config.get('loss_type', 'cross_entropy_loss').lower()
if rl_config := config.get('rl', None):
self.loss_type = rl_config['method']
self.focal_loss_gamma = config.get('focal_loss_gamma', 0)
if self.focal_loss_gamma > 0 and is_main_process():
print(f"Optimizing using '{self.loss_type}' with gamma={self.focal_loss_gamma}")
self.dpo_reference_mode = False
self.sampling_mode = False
for name, p in self.named_parameters():
p.original_name = name
# need to override this method
def to_layer_specs(self):
raise NotImplementedError()
def set_dpo_reference_mode(self, dpo_reference_mode):
self.dpo_reference_mode = dpo_reference_mode
def set_sampling_mode(self, sampling_mode):
self.sampling_mode = sampling_mode
# Reset cache when sampling mode is modified. This ensures it's initialized and also clears memory at the end.
self.cache_dict = defaultdict(transformers.DynamicCache)
# We could try to use static cache at some point. During early testing with relatively short sequence lengths,
# it was the same sampling speed as DynamicCache. Note: will need to pass cache_position in transformer layer
# if using StaticCache.
# def make_static_cache():
# return transformers.StaticCache(
# self.model_config,
# max_batch_size=1,
# max_cache_len=1024,
# device='cuda',
# dtype=self.dtype,
# )
# self.cache_dict = defaultdict(make_static_cache)
def set_cache(self, micro_batch_id):
self.cache = self.cache_dict[micro_batch_id]
def _partial_module_name_match(full_name, list_to_match):
return any(key in full_name for key in list_to_match)
def _replace_with_quantized_linear(parent_modules_map, name, full_name, quantization_config):
if isinstance(quantization_config, transformers.BitsAndBytesConfig):
_replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config)
elif isinstance(quantization_config, hqq_utils.CustomHQQConfig):
_replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config)
else:
raise NotImplementedError(f'Quantization config not implemented: {quantization_config}')
def _replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config):
"""Replace a Linear layer with a BNB quantized version."""
if quantization_config.llm_int8_skip_modules is not None and _partial_module_name_match(
full_name, quantization_config.llm_int8_skip_modules
):
return
module = parent_modules_map[name]
with accelerate.init_empty_weights():
if isinstance(module, nn.Conv1d):
in_features, out_features = module.weight.shape
else:
in_features = module.in_features
out_features = module.out_features
if quantization_config.quantization_method() == 'llm_int8':
parent_modules_map[name] = bnb.nn.Linear8bitLt(
in_features,
out_features,
module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold,
)
else:
extra_kwargs = (
{'quant_storage': quantization_config.bnb_4bit_quant_storage}
if 'quant_storage' in list(signature(bnb.nn.Linear4bit).parameters)
else {}
)
parent_modules_map[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
**extra_kwargs,
)
# Store the module class in case we need to transpose the weight later
parent_modules_map[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
parent_modules_map[name].requires_grad_(False)
def _replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config):
"""Replace a Linear layer with a HQQ quantized version."""
if _partial_module_name_match(full_name, quantization_config.skip_modules):
return
module = parent_modules_map[name]
quant_config_dict = quantization_config.get_dict(full_name)
hqq_linear = hqq_quantize.HQQLinear(
module,
quant_config=quant_config_dict,
compute_dtype=quantization_config.compute_dtype,
device=module.weight.device,
initialize=True,
del_orig=True,
)
# Quantization itself uses a decent amount of VRAM. Temporarily move each quantized parameter to the CPU as we
# finish, so the quant process doesn't OOM. Deepspeed will move everything to the correct device later.
hqq_linear.W_q.data = hqq_linear.W_q.data.to('cpu')
# Store the module class in case we need to transpose the weight later
hqq_linear.source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
hqq_linear.requires_grad_(False)
parent_modules_map[name] = hqq_linear
# modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
def _recursively_replace_with_quantized_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
):
"""
Recursively replace Linear and Conv1d layers in `model` with quantized versions, in place.
Layers matching `modules_to_not_convert` are skipped. Returns nothing.
"""
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if (isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = '.'.join(current_key_name)
if not any(
(key + '.' in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
_replace_with_quantized_linear(model._modules, name, current_key_name_str, quantization_config)
# copy over the original_name attribute we added earlier (needed for loading weights)
for orig_name, orig_p in module.named_parameters():
if hasattr(orig_p, 'original_name'):
for new_name, new_p in model._modules[name].named_parameters():
if new_name == orig_name:
new_p.original_name = orig_p.original_name
if len(list(module.children())) > 0:
_recursively_replace_with_quantized_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
)
# Remove the last key for recursion
current_key_name.pop(-1)
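`_recursively_replace_with_quantized_linear` walks `named_children()` and keeps a list of name components as it goes. It skips a layer when either of these holds:

- its bare name is in `modules_to_not_convert`;
- some key equals the full dotted path, or appears in the path immediately followed by `.`.

The following is a stdlib-only sketch of that naming and matching logic. A nested dict stands in for the module tree. `should_skip`, `walk` and `quantizable_layers` are hypothetical names.

```python
def should_skip(dotted_name, modules_to_not_convert):
    # Same test as above: a key matches if it equals the full dotted path, or appears anywhere
    # in it followed by '.' (a substring test, so 'attn' also matches 'self_attn.').
    return any((key + '.' in dotted_name) or (key == dotted_name) for key in modules_to_not_convert)


def walk(tree, prefix=()):
    # `tree` is a nested dict standing in for named_children(); any non-dict value is a leaf layer.
    for name, child in tree.items():
        path = prefix + (name,)
        if isinstance(child, dict):
            yield from walk(child, path)
        else:
            yield '.'.join(path)


def quantizable_layers(tree, modules_to_not_convert):
    # Also apply the bare-name check (`name not in modules_to_not_convert`).
    return [
        path
        for path in walk(tree)
        if path.split('.')[-1] not in modules_to_not_convert and not should_skip(path, modules_to_not_convert)
    ]


tree = {
    'model': {'layers': {'0': {'self_attn': {'q_proj': 'linear'}, 'mlp': {'up_proj': 'linear'}}}},
    'lm_head': 'linear',
}
print(quantizable_layers(tree, ['lm_head']))
# -> ['model.layers.0.self_attn.q_proj', 'model.layers.0.mlp.up_proj']
```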
class LoaderUtil:
def __init__(self, model_path, quantization_config, modules_to_not_quantize):
self.model_path = model_path
self.quantization_config = quantization_config
self.modules_to_not_quantize = modules_to_not_quantize
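local_rank = os.environ.get('LOCAL_RANK')
# Check before converting: int(None) would raise a TypeError before the assert could run.
assert local_rank is not None, 'LOCAL_RANK environment variable is not set'
self.local_rank = int(local_rank)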
self.device = get_accelerator().device_name(self.local_rank)
index_file = os.path.join(model_path, transformers.utils.SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(index_file):
checkpoint_files, checkpoint_metadata = transformers.utils.hub.get_checkpoint_shard_files(
model_path, index_file, local_files_only=True
)
self.checkpoint_metadata = checkpoint_metadata
else:
self.checkpoint_metadata = None
self.loaded_state_dict = None
def get_partial_state_dict(self, leaf_file):
if self.loaded_state_dict is None or leaf_file != self.loaded_state_dict[0]:
print(f'loading checkpoint file {leaf_file}')
state_dict = transformers.modeling_utils.load_state_dict(os.path.join(self.model_path, leaf_file))
state_dict = {re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in state_dict.items()}
self.loaded_state_dict = (leaf_file, state_dict)
return self.loaded_state_dict[1]
def maybe_quantize(self, module):
if self.quantization_config is None:
return
modules_to_not_convert = self.modules_to_not_quantize
if not isinstance(modules_to_not_convert, list):
modules_to_not_convert = [modules_to_not_convert]
_recursively_replace_with_quantized_linear(
module, modules_to_not_convert=modules_to_not_convert, quantization_config=self.quantization_config
)
# Make sure to set this or PEFT (and probably other things) will break in strange ways.
# We only need this because we do the loading and quanting ourselves.
self.is_loaded_in_4bit = True
def load_state_dict_into_module(self, module):
print(f'load params into module {type(module)}')
if isinstance(self.quantization_config, transformers.BitsAndBytesConfig):
# bnb needs to replace with quantized linear before weights are loaded
self.maybe_quantize(module)
param_renaming_map = {p.original_name: new_name for new_name, p in module.named_parameters()}
expected_keys = [p.original_name for p in module.parameters()]
# If we have any extra attributes on the parameter, loading with BNB 4bit params breaks, so delete them.
for p in module.parameters():
del p.original_name
if self.checkpoint_metadata is not None:
weight_map = self.checkpoint_metadata['weight_map']
weight_map = {re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in weight_map.items()}
needed_checkpoint_files = {weight_map[key.replace('orig.', '')] for key in expected_keys}
else:
needed_checkpoint_files = ['model.safetensors']
for checkpoint_file in needed_checkpoint_files:
state_dict = self.get_partial_state_dict(checkpoint_file)
renamed_state_dict = {param_renaming_map[k]: v for k, v in state_dict.items() if k in param_renaming_map}
for name, param in module.named_parameters():
if name in renamed_state_dict:
set_module_tensor_to_device(module, name, device='cpu', value=renamed_state_dict[name])
module.to(self.device)
if not isinstance(self.quantization_config, transformers.BitsAndBytesConfig):
self.maybe_quantize(module)
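Each pipeline layer opens only the safetensors shards that contain its own parameters. `load_state_dict_into_module` strips a leading `language_model.` from the index's `weight_map` keys, then looks up the shard for each expected key. Below is a standalone sketch of that lookup. `needed_shards` is a hypothetical helper, and the `weight_map` contents are made-up example data.

```python
import re

LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX = r'^language_model\.'


def needed_shards(weight_map, expected_keys):
    # Strip a leading 'language_model.' from checkpoint keys, as LoaderUtil does.
    weight_map = {re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in weight_map.items()}
    # Look up the shard holding each expected parameter (mirroring key.replace('orig.', '')).
    return {weight_map[key.replace('orig.', '')] for key in expected_keys}


# Example data only: an index whose keys carry the 'language_model.' prefix.
weight_map = {
    'language_model.model.embed_tokens.weight': 'model-00001-of-00002.safetensors',
    'language_model.model.layers.0.mlp.up_proj.weight': 'model-00001-of-00002.safetensors',
    'language_model.model.layers.1.mlp.up_proj.weight': 'model-00002-of-00002.safetensors',
}
print(needed_shards(weight_map, ['model.layers.1.mlp.up_proj.weight']))
# -> {'model-00002-of-00002.safetensors'}
```

The `^` anchor means only a leading prefix is removed. Keys that merely contain `language_model.` somewhere else are left unchanged.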
================================================
FILE: pyproject.toml
================================================
[tool.black]
# Only used by `hf-doc-builder`.
line-length = 119
target-version = ['py38']
[tool.ruff]
target-version = "py38"
line-length = 119
extend-exclude = ["axolotl"]
[tool.ruff.format]
quote-style = "single"
[tool.ruff.lint]
extend-select = [
"C", # Complexity
"E", # PEP8 errors
"F", # Pyflakes
"I", # Import sorting
"UP", # Pyupgrade upgrades
"W", # PEP8 warnings
# "PT009", # Pytest assertions
]
ignore = [
"C901", # Function too complex
"E501", # Line length (handled by ruff-format)
"E741", # Allow the letter 'l'
"UP007", # X | Y style Unions
# "PT009", # self.assertEqual etc
]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["peft"]
[tool.pytest]
doctest_optionflags = [
"NORMALIZE_WHITESPACE",
"ELLIPSIS",
"NUMBER",
]
[tool.pytest.ini_options]
addopts = "--cov=src/peft --cov-report=term-missing --durations=10"
markers = [
"single_gpu_tests: tests that run on a single GPU",
"multi_gpu_tests: tests that run on multiple GPUs",
"regression: whether to run regression suite test",
"bitsandbytes: select bitsandbytes integration tests"
]
================================================
FILE: requirements.txt
================================================
torch
torchvision
torchaudio
accelerate
bitsandbytes
datasets
deepspeed
packaging
peft
safetensors
scipy
sentencepiece
tensorboard
toml
jsonlines
transformers
flash-attn
pyyaml
tqdm
triton
hqq
torch-optimi
ruff
# minimum Axolotl dependencies for what we use
addict
optimum
pynvml
evaluate
wandb
numba
colorama
trl
mlflow
termcolor
fastcore
================================================
FILE: tools/convert_dpo_dataset_to_chat_format.py
================================================
# Convert a DPO dataset with prompt, chosen, rejected fields into chat format.
# Usage: python convert_dpo_dataset_to_chat_format.py hf_username/some_dataset path/to/output/directory
import os
import sys
from pathlib import Path
import datasets
dataset_path, converted_path = sys.argv[1:]
dataset = datasets.load_dataset(dataset_path)
def convert(x):
prompt = x['prompt']
chosen = [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': x['chosen']}]
rejected = [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': x['rejected']}]
return {'chosen': chosen, 'rejected': rejected}
num_proc = min(64, os.cpu_count())
dataset = dataset.map(convert, num_proc=num_proc)
for name, split in dataset.items():
filepath = Path(converted_path) / f'{name}.json'
split.to_json(filepath)
================================================
FILE: tools/convert_ds_checkpoint_to_lora.py
================================================
# Very hacky script to convert pipeline parallel Deepspeed checkpoints into a saved lora model.
# I originally wrote this because I screwed up the lora model saving initially, and needed a
# way to turn the training checkpoints into saved lora models to test them.
import os.path
import re
from glob import glob
import torch
def convert_ds_checkpoint_to_lora(ds_checkpoint_dir, lora_output_dir):
layer_checkpoint_files = glob(os.path.join(ds_checkpoint_dir, 'layer_*-model_states.pt'))
combined_state_dict = {}
for path in layer_checkpoint_files:
match = re.fullmatch('layer_(.+)-model_states.pt', os.path.basename(path))
layer_idx = int(match.group(1)) - 2
state_dict = torch.load(path)
for name, weight in state_dict.items():
converted_name = name.replace('orig', f'base_model.model.model.layers.{layer_idx}').replace('.default', '')
combined_state_dict[converted_name] = weight
os.makedirs(lora_output_dir, exist_ok=True)
torch.save(combined_state_dict, os.path.join(lora_output_dir, 'adapter_model.bin'))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input')
parser.add_argument('--output')
args = parser.parse_args()
convert_ds_checkpoint_to_lora(args.input, args.output)
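The renaming in `convert_ds_checkpoint_to_lora` can be checked in isolation. This sketch factors it into a hypothetical `convert_name` helper. It uses the same regex (with the `.` escaped) and the same hard-coded `- 2` offset that the script subtracts from the DeepSpeed layer index.

```python
import os
import re


def convert_name(layer_file, param_name):
    # Pipeline checkpoints here are one file per layer, e.g. layer_05-model_states.pt.
    match = re.fullmatch(r'layer_(.+)-model_states\.pt', os.path.basename(layer_file))
    # Same hard-coded offset as the script: DeepSpeed layer index minus 2.
    layer_idx = int(match.group(1)) - 2
    # 'orig' (the wrapped decoder block) becomes the PEFT-style module path, and
    # '.default' (PEFT's default adapter name) is dropped.
    return param_name.replace('orig', f'base_model.model.model.layers.{layer_idx}').replace('.default', '')


print(convert_name('/ckpt/global_step100/layer_05-model_states.pt', 'orig.self_attn.q_proj.lora_A.default.weight'))
# -> base_model.model.model.layers.3.self_attn.q_proj.lora_A.weight
```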
================================================
FILE: tools/merge_lora.py
================================================
# Usage: python merge_lora.py input_path lora_path output_path [--no-gpu]
# The output path is created if it doesn't exist.
import argparse
import os
import re
import shutil
from pathlib import Path

import safetensors
import safetensors.torch  # load_file/save_file live in this submodule
import torch
from tqdm import tqdm

import peft

parser = argparse.ArgumentParser()
parser.add_argument('input_path', type=str, help='The path to the input directory.')
parser.add_argument('lora_path', type=str, help='The path to the LoRA directory.')
parser.add_argument('output_path', type=str, help='The path to the output directory.')
parser.add_argument('--no-gpu', action='store_true', help='Use CPU for merging.')
args = parser.parse_args()

input_path, lora_path, output_path = Path(args.input_path), Path(args.lora_path), Path(args.output_path)
os.makedirs(output_path, exist_ok=True)

lora_config = peft.LoraConfig.from_json_file(lora_path / 'adapter_config.json')
# Standard LoRA scaling. This ignores `use_rslora`, for which PEFT scales by alpha / sqrt(r) instead.
scale = lora_config['lora_alpha'] / lora_config['r']

device = 'cpu' if args.no_gpu else 'cuda'

print('Loading LoRA model...')
# Prefer adapter_model.safetensors; fall back to adapter_model.bin.
if (lora_path / 'adapter_model.safetensors').exists():
    lora_state = safetensors.torch.load_file(lora_path / 'adapter_model.safetensors', device=device)
else:
    lora_state = torch.load(lora_path / 'adapter_model.bin', map_location=device)


def find_lora_weights(key):
    # Return the (lora_A, lora_B) pair for base weight `key`, or (None, None) if it has no adapter.
    # Use removesuffix, not strip: str.strip('.weight') removes any of those *characters* from
    # both ends, which mangles names such as '...gate.weight' into '...ga'.
    lora_A = None
    lora_B = None
    for lora_key, lora_weight in lora_state.items():
        if key.removesuffix('.weight') in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
                lora_B = lora_weight
            else:
                raise RuntimeError(
                    f'Unexpected adapter tensor {lora_key} (e.g. a DoRA magnitude vector); only lora_A/lora_B can be merged'
                )
    assert not ((lora_A is None) ^ (lora_B is None))
    return lora_A, lora_B


shards = []
for shard in input_path.glob('model*.safetensors'):
    shards.append(shard)

print('Copying unmergeable files to output')
for filepath in input_path.glob('*'):
    if filepath in shards:
        continue
    if filepath.is_dir():
        continue
    if filepath.suffix == '.gguf':
        # Skip unrelated stray quantizations
        continue
    if filepath.suffix == '.safetensors':
        # Skip other .safetensors files (possibly a consolidated checkpoint); only model*.safetensors shards are merged.
        continue
    print(f'copying {filepath.name} to output')
    shutil.copy(filepath, output_path)

print('Merging and copying state_dict to output')
found = 0
for shard in (pbar := tqdm(shards)):
    tensors = {}
    with safetensors.safe_open(shard, framework='pt', device=device) as f:
        metadata = f.metadata()
        for key in f.keys():
            # Some multimodal checkpoints prefix text weights with 'language_model.'; strip it before matching LoRA keys.
            lora_key = re.sub(r'^language_model\.', '', key)
            tensor = f.get_tensor(key)
            lora_A, lora_B = find_lora_weights(lora_key)
            if lora_A is not None:
                found += 1
                pbar.set_description(f'found lora weights for {key}: {lora_A.size()}, {lora_B.size()}')
                old_type = tensor.dtype
                # W' = W + scale * B @ A, computed in fp32 and cast back to the original dtype.
                tensor = tensor.to(torch.float32)
                tensor += scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)
                tensor = tensor.to(old_type)
            tensors[key] = tensor
    safetensors.torch.save_file(tensors, output_path / shard.name, metadata=metadata)
print(f"Applied LoRA to {found} tensors.")
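The merge above applies the standard LoRA update W' = W + (alpha / r) * B @ A to each matched base weight. The sketch below checks that rule on toy matrices using plain Python lists rather than torch. It also shows why `removesuffix`, not `strip`, is the right way to derive the lookup key. `matmul`, `matvec` and `merge` are hypothetical helpers, and the matrix values are illustrative.

```python
def matmul(X, Y):
    # (n, k) @ (k, m) -> (n, m)
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))] for i in range(len(X))]


def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def merge(W, A, B, alpha, r):
    # W' = W + (alpha / r) * B @ A, with B of shape (out, r) and A of shape (r, in)
    scale = alpha / r
    delta = matmul(B, A)
    return [[W[i][j] + scale * delta[i][j] for j in range(len(W[0]))] for i in range(len(W))]


# Toy shapes: out_features=2, in_features=3, rank r=1.
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
A = [[1.0, 2.0, 3.0]]  # lora_A: (r, in)
B = [[0.5], [-1.0]]    # lora_B: (out, r)
alpha, r = 2, 1

W_merged = merge(W, A, B, alpha, r)

# The merged weight should reproduce base output + scaled adapter output for any input.
x = [1.0, 1.0, 1.0]
lora_out = [alpha / r * y for y in matvec(B, matvec(A, x))]
expected = [w + l for w, l in zip(matvec(W, x), lora_out)]
print(matvec(W_merged, x), expected)  # [9.0, -12.0] [9.0, -12.0]

# Key derivation pitfall: strip() removes a character set, removesuffix() removes a suffix.
key = 'model.layers.0.block_sparse_moe.gate.weight'
print(key.strip('.weight'))         # model.layers.0.block_sparse_moe.ga
print(key.removesuffix('.weight'))  # model.layers.0.block_sparse_moe.gate
```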
================================================
FILE: tools/test_sampling.py
================================================
# deepspeed --num_gpus=1 --module tools.test_sampling --config ~/code/qlora-pipe-configs/config_8b_dpo.toml
import argparse
import json
import os.path

import bitsandbytes
import deepspeed
import toml
import transformers

import utils.engine as engine
from train import load_pipeline_model_with_lora
from utils.utils import DTYPE_MAP, is_main_process

PROMPT_FORMAT = """<|start_header_id|>user<|end_header_id|>

{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

PROMPTS = [
    'Where is Popeye Village located?',
    'What is the name of Sweden in Swedish?',
]

parser = argparse.ArgumentParser()
parser.add_argument('--config', help='Path to TOML configuration file.')
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()


if __name__ == '__main__':
    with open(args.config) as f:
        config = toml.load(f)
    config['full_fine_tune'] = True

    if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
        # engine.initialize() will load deepspeed config from args
        ds_config = None
    else:
        # The necessary ds_config fields are taken from the TOML config file.
        ds_config = {
            'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1),
            'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1),
            'gradient_clipping': config.get('gradient_clipping', 1.0),
            'steps_per_print': config.get('steps_per_print', 1),
        }

    deepspeed.init_distributed()

    with open(os.path.join(config['model'], 'config.json')) as f:
        model_config = json.load(f)
    model_type = model_config.get('model_type', 'llama')

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config['model'],
        local_files_only=True,
        model_max_length=int(1e30),
        padding_side='left',
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time.
    bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda

    def bnb_cuda_hijack(self, device):
        if getattr(self, 'already_quantized', False):
            self.data = self.data.to(device)
            self.quant_state.to(device)
            return self
        self.already_quantized = True
        return bnb_cuda_old(self, device)

    bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack

    pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type)

    kwargs = {}
    if sampling_settings := config.get('sampling', None):
        for k, v in sampling_settings.items():
            kwargs['sampling_' + k] = v

    model_engine, _ = engine.initialize(
        args=args,
        model=pipeline_model,
        lora_model=lora_model,
        config=ds_config,
        tokenizer=tokenizer,
        **kwargs,
    )

    weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))]
    model_engine.communication_data_type = weight_dtype

    prompts = [PROMPT_FORMAT.format(prompt) for prompt in PROMPTS]
    # prompts = [[PROMPT_FORMAT.format(prompt) for prompt in PROMPTS]]
    outputs = model_engine.sample_batch(prompts, max_new_tokens=500)
    if is_main_process():
        for text in outputs:
            print(text)
            print('-' * 80)
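`bnb_cuda_hijack` is an instance of a general monkeypatching pattern. First, keep a reference to the original method. Then replace the class attribute with a wrapper. Finally, store a flag on each instance so the expensive first-call work (here, quantization) runs only once. The sketch below shows the pattern with a hypothetical `FakeParam` class standing in for `bitsandbytes.nn.modules.Params4bit`. It does not touch bitsandbytes, and its quantize-on-move behaviour is a simplified assumption.

```python
class FakeParam:
    """Hypothetical stand-in for Params4bit: the original .cuda() 'quantizes' and moves."""

    def __init__(self):
        self.device = 'cpu'
        self.quantize_calls = 0

    def cuda(self, device):
        self.quantize_calls += 1  # pretend this is the expensive quantization step
        self.device = device
        return self


_original_cuda = FakeParam.cuda  # keep the original so the wrapper can delegate to it


def cuda_once(self, device):
    if getattr(self, 'already_quantized', False):
        # Already quantized: only move the data, skip re-quantization.
        self.device = device
        return self
    self.already_quantized = True
    return _original_cuda(self, device)


FakeParam.cuda = cuda_once  # patch the class, so every instance picks up the wrapper

p = FakeParam()
p.cuda('cuda:0')   # first move: quantizes
p.device = 'cpu'   # e.g. offloaded back to CPU
p.cuda('cuda:0')   # second move: no re-quantization
print(p.quantize_calls, p.device)  # 1 cuda:0
```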
================================================
FILE: train.py
================================================
import argparse
import glob
import itertools
import json
import os
import shutil
import time
from contextlib import contextmanager
from datetime import datetime, timezone
import bitsandbytes
import deepspeed
import toml
import torch
import transformers
from deepspeed.runtime.pipe.module import LayerSpec
from hqq.core import quantize as hqq_quantize
from torch.utils.tensorboard import SummaryWriter
import utils.dataloader as dataloader
import utils.engine as engine
import utils.hqq_utils as hqq_utils
import models.models as models
import utils.unsloth_utils as unsloth_utils
from utils.dataset_utils import load_datasets
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
from utils.saver import Saver
from utils.utils import DTYPE_MAP, is_main_process

parser = argparse.ArgumentParser()
parser.add_argument('--config', help='Path to TOML configuration file.')
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--debug_dataset', type=int, help='print out this many training examples and then quit')
parser.add_argument(
    '--resume_from_checkpoint',
    action='store_true',
    default=None,
    help='resume training from the most recent checkpoint',
)
parser.add_argument('--no_quantiles', action='store_true', help='suppress output of quantile metrics')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()


def print_model_info(model):
    if not is_main_process():
        return
    print(model)
    for name, module in model.named_modules():
        print(f'{type(module)}: {name}')
        for pname, p in module.named_parameters(recurse=False):
            print(pname)
            print(p.dtype)
            print(p.device)
            print(p.requires_grad)
            print()


def set_config_defaults(config):
    config['full_fine_tune'] = config.get('full_fine_tune', False)
    config['load_in_4bit'] = config.get('load_in_4bit', False)


def get_most_recent_run_dir(output_dir):
    return sorted(glob.glob(os.path.join(output_dir, '*')))[-1]


def write_metrics(tb_writer, prefix, metrics, step):
    loss = metrics[0].mean().item()
    tb_writer.add_scalar(f'{prefix}/loss', loss, step)

    if len(metrics) > 1:
        losses = metrics[1].view(-1)
        positive_losses = losses > 0
        tb_writer.add_histogram(f'{prefix}/log_loss_hist', torch.log(losses[positive_losses]), step)
        if not args.no_quantiles:
            sorted_losses, sorted_losses_idx = torch.sort(losses)
            quantiles = torch.tensor(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.999], dtype=torch.float32
            ).to(losses.device)
            quantiles_idx = [int(len(losses) * quantile) for quantile in quantiles]
            loss_quantiles = [sorted_losses[i] for i in quantiles_idx]
            for quantile, value in zip(quantiles, loss_quantiles):
                tb_writer.add_scalar(f'{prefix}/loss_quantile_{quantile:.3f}', value, step)

    if len(metrics) > 2:
        hidden_norm_avg = metrics[2].mean().item()
        tb_writer.add_scalar(f'{prefix}/hidden_norm_avg', hidden_norm_avg, step)
        hidden_state_norms = metrics[2].view(-1)
        tb_writer.add_histogram(f'{prefix}/hidden_norm_hist', hidden_state_norms, step)

    if len(metrics) > 3:
        entropy = metrics[3].view(-1)
        tb_writer.add_scalar(f'{prefix}/entropy', entropy.mean().item(), step)
        if not args.no_quantiles:
            assert entropy.size() == losses.size(), (entropy.size(), losses.size())
            sorted_entropy = entropy[sorted_losses_idx]
            entropy_quantiles = []
            for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):
                entropy_quantiles.append(sorted_entropy[i:j].mean())
            for quantile, value in zip(quantiles, entropy_quantiles):
                tb_writer.add_scalar(f'{prefix}/entropy_quantile_{quantile:.3f}', value, step)

    if len(metrics) > 4:
        normalised_entropy = metrics[4].view(-1)
        tb_writer.add_scalar(f'{prefix}/normalised_entropy', normalised_entropy.mean().item(), step)
        if not args.no_quantiles:
            assert normalised_entropy.size() == losses.size()
            sorted_normalised_entropy = normalised_entropy[sorted_losses_idx]
            normalised_entropy_quantiles = []
            for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):
                normalised_entropy_quantiles.append(sorted_normalised_entropy[i:j].mean())
            for quantile, value in zip(quantiles, normalised_entropy_quantiles):
                tb_writer.add_scalar(f'{prefix}/normalised_entropy_quantile_{quantile:.3f}', value, step)

    if len(metrics) > 5:
        log_likelihood = metrics[5].mean()
        tb_writer.add_scalar(f'{prefix}/log_likelihood', log_likelihood.item(), step)
        likelihood = torch.exp(-log_likelihood).item()
        tb_writer.add_scalar(f'{prefix}/likelihood', likelihood, step)
        perplexity = torch.exp(log_likelihood).item()
        tb_writer.add_scalar(f'{prefix}/perplexity', perplexity, step)

    if len(metrics) > 6:
        mcfaddens_pseudo_r2 = metrics[6].mean()
        tb_writer.add_scalar(f'{prefix}/mcfaddens_pseudo_r2', mcfaddens_pseudo_r2.item(), step)

    if len(metrics) > 7:
        tb_writer.add_scalar(f'{prefix}/top1_accuracy', metrics[7].mean().item(), step)
        tb_writer.add_scalar(f'{prefix}/top5_accuracy', metrics[8].mean().item(), step)
        tb_writer.add_scalar(f'{prefix}/top20_accuracy', metrics[9].mean().item(), step)

    if len(metrics) > 10:
        tb_writer.add_scalar(f'{prefix}/load_balancing_loss', metrics[10].mean().item(), step)

    if len(metrics) > 11:
        tb_writer.add_scalar(f'{prefix}/alternate_load_balancing_loss', metrics[11].mean().item(), step)

    return loss


def evaluate_single(model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps):
    orig_micro_batches = model_engine.micro_batches
    model_engine.micro_batches = eval_gradient_accumulation_steps
    iterator = iter(eval_dataloader)
    all_metrics = None
    while True:
        metrics = model_engine.eval_batch(iterator)
        eval_dataloader.sync_epoch()
        if all_metrics is None:
            all_metrics = [[] for _ in range(len(metrics))]
        # Stop once the loader has wrapped around into its second epoch, i.e. the whole eval set has been seen.
        if eval_dataloader.epoch == 2:
            break
        for i, metric in enumerate(metrics):
            all_metrics[i].append(metric)

    eval_dataloader.reset()
    model_engine.micro_batches = orig_micro_batches
    eval_metrics = [torch.cat(metric_list) for metric_list in all_metrics]
    loss = None
    if is_main_process():
        loss = write_metrics(tb_writer, f'eval/{name}', eval_metrics, step)
    return loss


def evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps):
    if is_main_process():
        print('Running eval')
    start = time.time()
    loss = []
    for name, eval_dataloader in eval_dataloaders.items():
        loss_or_none = evaluate_single(
            model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps
        )
        if loss_or_none is not None:
            loss.append(loss_or_none)
    duration = time.time() - start
    if is_main_process():
        tb_writer.add_scalar('eval/eval_time_sec', duration, step)
    return sum(loss) / len(loss) if len(loss) > 0 else None


def apply_max_norm_regularization(model, config):
    # modified from https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
    A_keys = []
    B_keys = []
    norms = []
    keys_scaled = 0
    lora_scale = config['lora_alpha'] / config['lora_rank']

    # state_dict() tensors share storage with the parameters, so the in-place scaling below updates the model.
    state_dict = model.state_dict()
    for key in state_dict.keys():
        if 'lora_A' in key:
            A_keys.append(key)
            B_keys.append(key.replace('lora_A', 'lora_B'))

    for i in range(len(A_keys)):
        A = state_dict[A_keys[i]]
        B = state_dict[B_keys[i]]
        W = B @ A
        W *= lora_scale

        if 'scale_weight_norms' in config:
            max_norm = config['scale_weight_norms']
            norm = W.norm().clamp(min=max_norm / 2)
            desired = torch.clamp(norm, max=max_norm)
            ratio = desired.cpu() / norm.cpu()
            # Scaling both A and B by sqrt(ratio) scales B @ A by exactly ratio.
            sqrt_ratio = ratio**0.5
            if ratio != 1:
                keys_scaled += 1
                state_dict[A_keys[i]] *= sqrt_ratio
                state_dict[B_keys[i]] *= sqrt_ratio
        else:
            ratio = 1.0
        scalednorm = W.norm() * ratio
        norms.append(scalednorm.item())

    if len(norms) > 0:
        norms = torch.tensor(norms, dtype=torch.float32)
        if torch.any(torch.isnan(norms)):
            raise RuntimeError('NaN detected in norms, probably some/all weights are NaN')
        avg_norm = sum(norms) / len(norms)
        max_norm = max(norms)
    else:
        avg_norm = 0
        max_norm = 0
    return keys_scaled, avg_norm, max_norm, norms
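

def parse_layers_to_transform(spec):
    """Parse a spec such as '0:3,8:9' into [0, 1, 2, 3, 8, 9]; both ends of each range are inclusive."""
    parts = spec.split(',')
    result = []
    for part in parts:
        start, stop = part.split(':')
        result.extend(range(int(start), int(stop) + 1))
    return result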
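

@contextmanager
def one_at_a_time():
    # Lets the processes on each node run the `with` body in turn (local rank 0 first), with a
    # barrier between turns, so on each node only one process is inside the body at a time.
    for i in range(int(os.environ['LOCAL_SIZE'])):
        if i == int(os.environ['LOCAL_RANK']):
            yield
        deepspeed.comm.barrier()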


def load_pipeline_model_with_lora(config, model_type):
    full_fine_tune = config['full_fine_tune']

    if config.get('quantization', None):
        assert not full_fine_tune
        no_quant_modules = ['lm_head']
        if model_type == 'mixtral':
            # the expert routing weights are tiny and probably important, don't quantize
            no_quant_modules.append('gate')
        if bnb_quant_config := config['quantization'].get('bnb', None):
            if bnb_compute_dtype := bnb_quant_config.get('bnb_4bit_compute_dtype', None):
                bnb_quant_config['bnb_4bit_compute_dtype'] = DTYPE_MAP[bnb_compute_dtype]
            if 'bnb_4bit_quant_type' not in bnb_quant_config:
                # Always want to default to nf4 if not specified.
                bnb_quant_config['bnb_4bit_quant_type'] = 'nf4'
            if llm_int8_skip_modules := bnb_quant_config.get('llm_int8_skip_modules', None):
                no_quant_modules.extend(llm_int8_skip_modules)
                no_quant_modules = list(set(no_quant_modules))
            bnb_quant_config['llm_int8_skip_modules'] = no_quant_modules
            quantization_config = transformers.BitsAndBytesConfig(**bnb_quant_config)
        elif hqq_quant_config := config['quantization'].get('hqq', None):
            quantization_config = hqq_utils.CustomHQQConfig(**hqq_quant_config)
            # Use ATEN backend if possible, else PYTORCH. PYTORCH_COMPILE was only a tiny bit faster, and requires triton nightly.
            hqq_quantize.HQQLinear.set_backend(
                hqq_quantize.HQQBackend.ATEN if quantization_config.use_aten() else hqq_quantize.HQQBackend.PYTORCH
            )
        else:
            raise NotImplementedError('Invalid quantization config')
        if is_main_process():
            print(f'Quantization config: {quantization_config}')
    else:
        quantization_config = None

    if model_type == 'llama':
        model = models.LlamaForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'mixtral':
        model = models.MixtralForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'qwen2':
        model = models.Qwen2ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'cohere':
        model = models.CohereForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'phi3':
        model = models.Phi3ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'gemma2':
        model = models.Gemma2ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'mistral' or model_type == 'mistral3':
        model = models.MistralForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'gemma3':
        model = models.Gemma3ForCausalLMPipe(config, quantization_config=quantization_config)
    elif model_type == 'cohere2':
        model = models.Cohere2ForCausalLMPipe(config, quantization_config=quantization_config)
    else:
        raise NotImplementedError(f'model_type {model_type} is not implemented')

    # CAREFUL! The "primary" layers of the model have to have 'decoderlayer' in them for
    # activation checkpointing to automatically work correctly.
    layers = model.to_layer_specs()
    checkpointable_layers = set()
    for layer in layers:
        if isinstance(layer, LayerSpec) and 'decoderlayer' in layer.typename.__name__.lower():
            checkpointable_layers.add(layer.typename.__name__)
    checkpointable_layers = list(checkpointable_layers)

    partition_method = 'estimated_size'
    if config.get('activation_checkpointing', False):
        # NOTE: must use a reentrant checkpointing function for MLP offloading to work.
        if config['activation_checkpointing'] == 'unsloth':
            checkpoint_func = unsloth_utils.unsloth_checkpoint
        elif config['activation_checkpointing'] == 'cpu':
            deepspeed.checkpointing.configure(None, checkpoint_in_cpu=True)
            checkpoint_func = deepspeed.checkpointing.checkpoint
        else:
            checkpoint_func = deepspeed.checkpointing.checkpoint
        pipeline_model = engine.CustomPipelineModule(
            layers=layers,
            num_stages=config['pipeline_stages'],
            activation_checkpoint_interval=1,
            checkpointable_layers=checkpointable_layers,
            activation_checkpoint_func=checkpoint_func,
            partition_method=partition_method,
            use_column_major_topology=config.get('use_column_major_topology', False),
            model=model,
            dynamic_shape=True,
        )
    else:
        pipeline_model = engine.CustomPipelineModule(
            layers=layers,
            num_stages=config['pipeline_stages'],
            partition_method=partition_method,
            use_column_major_topology=config.get('use_column_major_topology', False),
        )

    target_modules = config.get('target_modules', 'all-linear')
    if full_fine_tune:
        lora_model = None
        lora_config = None
        for name, p in model.named_parameters():
            p.original_name = name
        if isinstance(target_modules, list):
            for name, p in model.named_parameters():
                if not any(target in name for target in target_modules):
                    p.requires_grad = False
                    print(f'not training {name} because it is not present in target_modules')
    else:
        layers_to_transform = (
            parse_layers_to_transform(config['layers_to_transform']) if 'layers_to_transform' in config else None
        )
        lora_config = LoraConfig(
            r=config['lora_rank'],
            lora_alpha=config['lora_alpha'],
            target_modules=target_modules,
            modules_to_save=config.get('modules_to_save', []),
            lora_dropout=config.get('lora_dropout', 0),
            layers_to_transform=layers_to_transform,
            bias='none',
            task_type='CAUSAL_LM',
            use_dora=config.get('use_dora', False),
        )

        lora_model = get_peft_model(model, lora_config)
        # When the underlying weights are floats, PEFT creates the LoRA weights in that same dtype,
        # so cast the trainable parameters to lora_weight_dtype here.
        for p in lora_model.parameters():
            if p.requires_grad:
                p.data = p.data.to(DTYPE_MAP[config.get('lora_weight_dtype', 'float32')])

        lora_model.model.config.use_cache = False
        for name, p in lora_model.named_parameters():
            p.original_name = name

    return pipeline_model, lora_model, lora_config


if __name__ == '__main__':
    # TODO: if resuming from checkpoint, probably should read all config files from checkpoint dir
    # rather than assume they are unchanged on the command line
    with open(args.config) as f:
        config = toml.load(f)
    set_config_defaults(config)

    if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
        # engine.initialize() will load deepspeed config from args
        ds_config = None
    else:
        # The necessary ds_config fields are taken from the TOML config file.
        ds_config = {
            'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1),
            'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1),
            'gradient_clipping': config.get('gradient_clipping', 1.0),
            'steps_per_print': config.get('steps_per_print', 1),
        }

    # The command-line flag (None when not passed) takes precedence over the TOML setting.
    if args.resume_from_checkpoint is not None:
        resume_from_checkpoint = args.resume_from_checkpoint
    else:
        resume_from_checkpoint = config.get('resume_from_checkpoint', False)

    deepspeed.init_distributed()

    with open(os.path.join(config['model'], 'config.json')) as f:
        model_config = json.load(f)
    model_type = model_config.get('model_type', 'llama')

    # Pad on left to support training techniques that involve sampling from the model.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config['model'], local_files_only=True, model_max_length=int(1e30), padding_side='left'
    )
    # TODO: do we want to do this with cohere models? By default the EOS token is <|END_OF_TURN_TOKEN|>
    # if model_type == 'cohere':
    #     tokenizer.eos_token = '<EOS_TOKEN>'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_data, eval_data_map = load_datasets(config, tokenizer)

    if args.debug_dataset:
        if is_main_process():
            for i, item in enumerate(iter(train_data)):
                print('input_ids:')
                print(item['input_ids'])
                print('decoded input_ids:')
                print(tokenizer.decode(item['input_ids']))
                print('attention_mask:')
                print(item['attention_mask'])
                print('labels:')
                print(item['labels'])
                if 'rejected_input_ids' in item:
                    print('rejected_input_ids:')
                    print(item['rejected_input_ids'])
                    print('decoded rejected_input_ids:')
                    print(tokenizer.decode(item['rejected_input_ids']))
                    print('rejected_attention_mask:')
                    print(item['rejected_attention_mask'])
                    print('rejected_labels:')
                    print(item['rejected_labels'])
                print('-' * 80)
                if i >= args.debug_dataset - 1:
                    break
        quit()

    # for testing
    # train_data = train_data.select(list(range(100)))

    # if this is a new run, create a new dir for it
    if not resume_from_checkpoint and is_main_process():
        run_dir = os.path.join(config['output_dir'], datetime.now(timezone.utc).strftime('%Y%m%d_%H-%M-%S'))
        os.makedirs(run_dir, exist_ok=True)
        shutil.copy(args.config, run_dir)
        if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
            shutil.copy(args.deepspeed_config, run_dir)
    # wait for all processes then get the most recent dir (may have just been created)
    deepspeed.comm.barrier()
    run_dir = get_most_recent_run_dir(config['output_dir'])

    # Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time.
    bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda

    def bnb_cuda_hijack(self, device):
        if getattr(self, 'already_quantized', False):
            self.data = self.data.to(device)
            self.quant_state.to(device)
            return self
        self.already_quantized = True
        return bnb_cuda_old(self, device)

    bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack

    pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type)

    parameters_to_train = [p for p in pipeline_model.parameters() if p.requires_grad]

    optim_config = config['optimizer']

    def get_optimizer(model_parameters):
        lr = optim_config['lr']
        optim_type = optim_config['type'].lower()
        optimizer_kwargs = {
            'params': model_parameters,
            'lr': lr,
            'betas': (optim_config.get('beta1', 0.9), optim_config.get('beta2', 0.99)),
            'weight_decay': optim_config.get('weight_decay', 0.01),
            'eps': optim_config.get('eps', 1e-6),
        }
        if optim_type == 'adamw':
            optimizer_cls = deepspeed.ops.adam.FusedAdam
        elif optim_type == 'adamw8bit':
            optimizer_cls = bitsandbytes.optim.AdamW8bit
        elif optim_type == 'adamw_kahan':
            import optimi

            optimizer_cls = optimi.AdamW
            optimizer_kwargs['kahan_sum'] = optim_config.get('kahan_sum', True)
        else:
            raise NotImplementedError(optim_type)

        if optim_config.get('use_loraplus', False):
            loraplus_lr_ratio = optim_config.get('loraplus_lr_ratio', 16)
            # TODO: handle params being thrown out here; why is it included in the first place?
            # create_loraplus_optimizer builds its own parameter groups from `model`, so 'params' must be removed.
            del optimizer_kwargs['params']
            return create_loraplus_optimizer(
                model=pipeline_model,
                optimizer_cls=optimizer_cls,
                loraplus_lr_ratio=loraplus_lr_ratio,
                **optimizer_kwargs,
            )

        return optimizer_cls(**optimizer_kwargs)

    kwargs = {}
    if sampling_settings := config.get('sampling', None):
        for k, v in sampling_settings.items():
            kwargs['sampling_' + k] = v
    rejected_sampling = config.get('rejected_sampling', False)
    rl_config = config.get('rl', None)
    model_engine, optimizer = engine.initialize(
        args=args,
        model=pipeline_model,
        model_parameters=parameters_to_train,
        optimizer=get_optimizer,
        lora_model=lora_model,
        config=ds_config,
        tokenizer=tokenizer,
        rl_config=rl_config,
        rejected_sampling=rejected_sampling,
        rejected_sampling_max_new_tokens=config.get('rejected_sampling_max_new_tokens', 1e9),
        **kwargs,
    )

    # TODO: I recently realized that we are casting things to fp16/bf16 even though the DS config
    # is not in fp16/bf16 mode. Putting DS in fp16/bf16 mode changes things in many places, e.g.
    # it can give you a BF16_Optimizer wrapper that accumulates grads in fp32, the communication dtype
    # is different, etc. I need to look through all the implications of this. For now, we keep the
    # normal optimizer but change the communication dtype so that we don't unnecessarily cast grads
    # to fp32.
    weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))]
    model_engine.communication_data_type = weight_dtype

    # TODO: the main DeepSpeedEngine forces all parameters to the GPU, and also does things like
    # broadcast all parameters from data parallel rank 0 to all other ranks. Thus, MLP offloading
    # must come after engine.initialize(). If we want to avoid loading everything onto GPUs only
    # to offload the MLPs, we have to rewrite a lot of code to work around things.
    if config.get('offload_mlp_to_cpu', False):
        # MLP offloading only works with activation checkpointing
        assert config.get('activation_checkpointing', False)
        for module in pipeline_model.modules():
            if hasattr(module, 'move_mlp_to_cpu'):
                module.move_mlp_to_cpu()
        torch.cuda.empty_cache()

    train_dataloader = dataloader.PipelineDataLoader(
        train_data,
        tokenizer,
        model_engine.train_micro_batch_size_per_gpu(),
        model_engine.gradient_accumulation_steps(),
        model_engine.grid.get_data_parallel_world_size(),
        model_engine.grid.get_data_parallel_rank(),
        group_by_length=config.get('group_by_length', False),
        batch_size_tokens=config.get('batch_size_tokens', None),
        return_dict=rejected_sampling,
        rl=(rl_config is not None),
    )
    model_engine.set_dataloader(train_dataloader)
    steps_per_epoch = len(train_dataloader) // model_engine.gradient_accumulation_steps()
    model_engine.total_steps = steps_per_epoch * config['epochs']

    if is_main_process():
        # Warn if the eval dataset is unusually large compared to eval_steps.
        eval_data_length = sum([len(eval_data) for eval_data in eval_data_map.values()])
        train_data_length = len(train_data)
        evals_per_epoch = steps_per_epoch / config['eval_steps']
        relative_eval_time = evals_per_epoch * eval_data_length
        # A train step is very roughly 3x slower than an eval step, due to backprop plus (usually) activation checkpointing.
        relative_train_time = train_data_length * 3
        # Warn if more than ~15% of the total time would be spent evaluating.
        fraction_evaling = relative_eval_time / (relative_eval_time + relative_train_time)
        print()
        print(
            f'eval_data_length: {eval_data_length}, eval_steps: {config["eval_steps"]}; evals per epoch: {evals_per_epoch}. '
            f'We will be spending approximately {fraction_evaling * 100:.2f}% of our time evaluating.'
        )
        if fraction_evaling > 0.15:
            print(
                'WARNING: eval dataset is unusually large compared to eval_steps. We will spend a lot of time evaluating. Lowering eval_size and/or raising eval_steps is recommended.'
            )
        print()

    # handle Deepspeed optimizer wrapper (e.g. BF16_Optimizer)
    optimizer = getattr(optimizer, 'optimizer', optimizer)

    warmup_steps = config.get('warmup_steps', 0)
    # Fractional values less than 1 are converted into "fraction of epoch" worth of steps
    if 0 < warmup_steps < 1:
        warmup_steps = int(warmup_steps * steps_per_epoch)

    if config.get('lr_scheduler', 'constant') in ('constant', 'none'):
        lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
    elif config['lr_scheduler'] == 'cosine':
        total_steps = steps_per_epoch * config['epochs']
        total_steps -= warmup_steps
        lr_scheduler_kwargs = {
            'optimizer': optimizer,
            'T_max': total_steps,
        }
        if 'lr_min' in optim_config:
            lr_scheduler_kwargs['eta_min'] = optim_config['lr_min']
        # Normally, you would pass the lr_scheduler to deepspeed.initialize(). But we need the
        # global batch_size in order to make the lr_scheduler.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(**lr_scheduler_kwargs)
    else:
        raise NotImplementedError(f"lr_scheduler {config['lr_scheduler']} is not implemented")

    load_optimizer_states = config.get('load_optimizer_states', True)
    # If we don't load optimizer states when resuming, we can't use warmup, or the LR never changes
    # from its initial value (still don't know why).
    if warmup_steps > 0 and load_optimizer_states:
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1 / warmup_steps, total_iters=warmup_steps
        )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
        )
    model_engine.lr_scheduler = lr_scheduler
step = 1
if resume_from_checkpoint:
load_path, client_state = model_engine.load_checkpoint(
run_dir,
load_module_strict=False,
load_lr_scheduler_states='force_constant_lr' not in config,
load_optimizer_states=load_optimizer_states,
)
deepspeed.comm.barrier() # just so the print below doesn't get swamped
assert load_path is not None
train_dataloader.load_state_dict(client_state['custom_loader'])
step = client_state['step'] + 1
del client_state
# if we skip loading the optimizer states, we need to step the LR scheduler so we start at the right value
if not load_optimizer_states:
model_engine.lr_scheduler.step()
if is_main_process():
print(f'Resuming training from checkpoint. Resuming at epoch: {train_dataloader.epoch}, step: {step}')
if 'force_constant_lr' in config:
model_engine.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
for pg in optimizer.param_groups:
pg['lr'] = config['force_constant_lr']
# this is a separate option, because if it's too high we might drop a significant fraction of the eval dataset
eval_gradient_accumulation_steps = (
config['eval_gradient_accumulation_steps'] if 'eval_gradient_accumulation_steps' in config else 1
)
# Eval dataset doesn't need to repeat; we just use this to track "epoch" so we know when we're done iterating over it.
eval_dataloaders = {
name: dataloader.PipelineDataLoader(
eval_data,
tokenizer,
model_engine.train_micro_batch_size_per_gpu(),
eval_gradient_accumulation_steps,
model_engine.grid.get_data_parallel_world_size(),
model_engine.grid.get_data_parallel_rank(),
shuffle=False,
group_by_length=False if 'group_by_length' not in config else config['group_by_length'],
batch_size_tokens=None if 'batch_size_tokens' not in config else config['batch_size_tokens'],
return_dict=rejected_sampling,
rl=(rl_config is not None),
)
for name, eval_data in eval_data_map.items()
}
tb_writer = SummaryWriter(log_dir=run_dir) if is_main_process() else None
epoch = train_dataloader.epoch
saver = Saver(model_engine, pipeline_model, train_dataloader, lora_config, run_dir, args, config)
if config.get('eval_before_first_step', False) and not resume_from_checkpoint:
loss = evaluate(model_engine, eval_dataloaders, tb_writer, 0, eval_gradient_accumulation_steps)
saver.append_eval_results(loss, save_best=False)
while True:
metrics = model_engine.train_batch()
train_dataloader.sync_epoch()
if lora_config is not None:
keys_scaled, avg_norm, max_norm, norms = apply_max_norm_regularization(pipeline_model, config)
new_epoch = saver.process_epoch(epoch, step)
finished_epoch = new_epoch != epoch
if is_main_process() and step % config['logging_steps'] == 0:
write_metrics(tb_writer, 'train', metrics, step)
tb_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
# TODO: gather the weight norms across all stages in the pipelined model, not just the first.
if lora_config is not None and len(norms) > 0:
tb_writer.add_scalar('train/weights_scaled', keys_scaled, step)
tb_writer.add_scalar('train/weight_norm_avg', avg_norm, step)
tb_writer.add_scalar('train/weight_norm_max', max_norm, step)
tb_writer.add_histogram('train/weight_norm_hist', norms, step)
tb_writer.add_scalar('train/epoch', step / steps_per_epoch, step)
if step % config['eval_steps'] == 0:
loss = evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps)
saver.append_eval_results(loss)
saver.process_step(step)
if finished_epoch:
epoch = new_epoch
if epoch is None:
break
step += 1
# The loop breaks before `step += 1`, so `step` is the last completed step here.
if (step % config['eval_steps'] != 0) and config.get('eval_after_last_step', False):
loss = evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps)
saver.append_eval_results(loss)
if is_main_process():
print('TRAINING COMPLETE!')
================================================
FILE: utils/dataloader.py
================================================
import math
import os.path
import sys
sys.path.insert(0, os.path.abspath('axolotl/src'))
import accelerate
import torch
import transformers
from deepspeed import comm as dist
from torch.utils.data import DataLoader
from axolotl.utils.collators import DataCollatorForSeq2Seq
from utils.utils import is_main_process
# A100 wants padding to a multiple of 64; other cards are efficient with smaller multiples, so just use 64
PAD_TO_MULTIPLE = 64
# Splits an example (feature dict) along the batch dimension into a list of examples.
def split_batch(example, pieces):
input_ids = example['input_ids']
if is_main_process():
print(f'before GAS splitting, input_ids shape: {input_ids.shape}, total tokens: {input_ids.numel()}')
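input_batch_size = input_ids.size(0)
assert input_batch_size % pieces == 0, f'batch size {input_batch_size} is not divisible into {pieces} pieces'
split_size = input_batch_size // pieces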
examples = [{} for _ in range(pieces)]
for key, tensor in example.items():
assert tensor.size(0) == input_batch_size
for i, tensor_slice in enumerate(torch.split(tensor, split_size)):
examples[i][key] = tensor_slice
return examples
# Merge lists of examples a and b, such that for each contiguous piece in the result, the first half comes from
# a and the second half from b. Used for DPO. The splitting must match how split_batch() does it.
def combine_piecewise(a, b, pieces):
assert len(a) == len(b)
split_size = len(a) // pieces
a_chunks = [a[i : i + split_size] for i in range(0, len(a), split_size)]
b_chunks = [b[i : i + split_size] for i in range(0, len(b), split_size)]
result = []
for a_chunk, b_chunk in zip(a_chunks, b_chunks):
result.extend(a_chunk)
result.extend(b_chunk)
return result
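# Illustrative sketch (hypothetical helper, not used by the dataloader): the row order that
# combine_piecewise() produces and how it lines up with split_batch(). Plain strings stand in
# for examples. Each of the `pieces` micro-batches ends up as [accepted chunk, rejected chunk].
# This is the layout unpack_accepted_rejected() in utils/engine.py assumes when it splits a
# micro-batch in half.
def sketch_piecewise_layout(accepted, rejected, pieces):
    assert len(accepted) == len(rejected) and len(accepted) % pieces == 0
    split_size = len(accepted) // pieces
    merged = []
    for start in range(0, len(accepted), split_size):
        merged.extend(accepted[start : start + split_size])
        merged.extend(rejected[start : start + split_size])
    # split_batch() then cuts the collated batch into `pieces` equal slices.
    micro_batch_size = len(merged) // pieces
    return [merged[i : i + micro_batch_size] for i in range(0, len(merged), micro_batch_size)]
# sketch_piecewise_layout(['a0', 'a1', 'a2', 'a3'], ['r0', 'r1', 'r2', 'r3'], 2)
#   -> [['a0', 'a1', 'r0', 'r1'], ['a2', 'a3', 'r2', 'r3']]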
# Flattens a list of examples with batch dimension into a list of examples with no batch dimension.
def flatten_examples(examples):
result = []
for example in examples:
batch_size = example['input_ids'].size(0)
new_examples = [{} for _ in range(batch_size)]
for key, tensor in example.items():
assert tensor.size(0) == batch_size
for i, tensor_slice in enumerate(tensor):
new_examples[i][key] = tensor_slice
result.extend(new_examples)
return result
def example_to_tuple(example):
return (example['input_ids'], example['attention_mask'], example['labels']), None
def shuffle_list(l, seed):
g = torch.Generator()
g.manual_seed(seed)
shuffle_idx = torch.randperm(len(l), generator=g).tolist()
new_l = [l[i] for i in shuffle_idx]
return new_l
def batch_size_tokens_after_padding(batch):
return max(math.ceil(pair[1] / PAD_TO_MULTIPLE) * PAD_TO_MULTIPLE for pair in batch) * len(batch)
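# Illustrative sketch (hypothetical helper, not used by the dataloader): the padded-token
# arithmetic from batch_size_tokens_after_padding(), applied to plain lists of lengths. It shows
# why grouping by length reduces padding: every row is padded to the longest row in its batch,
# rounded up to a multiple of 64 (PAD_TO_MULTIPLE).
import math
def sketch_padded_tokens(lengths, pad_to_multiple=64):
    padded_len = max(math.ceil(n / pad_to_multiple) * pad_to_multiple for n in lengths)
    return padded_len * len(lengths)
# Mixing short and long rows pads every row to 1024 tokens:
#   sketch_padded_tokens([10, 1000]) + sketch_padded_tokens([20, 990]) -> 2048 + 2048 = 4096
# Grouping similar lengths does not:
#   sketch_padded_tokens([10, 20]) + sketch_padded_tokens([990, 1000]) -> 128 + 2048 = 2176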
# A distributed batch sampler that supports grouping by length and token-based batch sizes (batch_size_tokens)
class DistributedBatchSamper(torch.utils.data.Sampler):
def __init__(
self,
dataset,
batch_size,
num_replicas,
rank,
batch_size_multiplier=1,
shuffle=True,
group_by_length=False,
seed=0,
batch_size_tokens=None,
):
self.dataset = dataset
self.batch_size = batch_size
self.batch_size_tokens = batch_size_tokens
self.batch_size_multiplier = batch_size_multiplier
self.num_replicas = num_replicas
self.rank = rank
# every global batch must be evenly divisible by this amount
self.chunk_size = self.num_replicas * self.batch_size_multiplier
self.shuffle = shuffle
self.group_by_length = group_by_length
self.seed = seed
# Make list of (index, size). Sort or shuffle as needed.
indices = list(enumerate(self.dataset['length']))
if self.group_by_length:
indices.sort(key=lambda t: t[1])
elif self.shuffle:
indices = shuffle_list(indices, self.seed)
# Group indices together into global batches.
global_batches = []
current_batch = []
for i in range(0, len(indices), self.chunk_size):
slice = indices[i : i + self.chunk_size]
if len(slice) < self.chunk_size:
# pad with random examples if slice is too small
padding_size = self.chunk_size - len(slice)
shuffled_indices = shuffle_list(indices, self.seed + 1)
if padding_size < len(shuffled_indices):
slice += shuffled_indices[:padding_size]
else:
slice += (shuffled_indices * math.ceil(padding_size / len(shuffled_indices)))[:padding_size]
if self.should_emit_current_batch(current_batch, slice):
global_batches.append(current_batch)
current_batch = []
current_batch.extend(slice)
# Emit anything remaining
if len(current_batch) > 0:
global_batches.append(current_batch)
if self.shuffle:
global_batches = shuffle_list(global_batches, self.seed + 2)
# make sure the largest batch comes first to OOM sooner rather than later
largest_global_batch = 0
max_tokens = 0
for global_batch_idx, batch in enumerate(global_batches):
total_batch_tokens = batch_size_tokens_after_padding(batch)
if total_batch_tokens > max_tokens:
max_tokens = total_batch_tokens
largest_global_batch = global_batch_idx
global_batches[0], global_batches[largest_global_batch] = (
global_batches[largest_global_batch],
global_batches[0],
)
batches_for_this_rank = [
global_batch[self.rank : len(global_batch) : self.num_replicas] for global_batch in global_batches
]
self.indices = [[i for i, _ in batch] for batch in batches_for_this_rank]
def should_emit_current_batch(self, current_batch, slice):
if not self.batch_size_tokens:
batch_size_after_appending = len(current_batch) // self.chunk_size + 1
return batch_size_after_appending > self.batch_size
else:
global_batch_size_tokens = self.batch_size_tokens * self.chunk_size
current_batch_tokens_after_appending = batch_size_tokens_after_padding(current_batch + slice)
if len(current_batch) > 0 and current_batch_tokens_after_appending > global_batch_size_tokens:
return True
return False
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
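# Illustrative sketch (hypothetical helpers, simplified from DistributedBatchSamper above): how
# global batches are formed when batch_size_tokens is set, and how each data-parallel rank takes
# its share. Lengths are consumed in chunks of chunk_size (num_replicas * gradient accumulation
# steps). A global batch is closed when adding the next chunk would exceed
# batch_size_tokens * chunk_size padded tokens. Unlike the real sampler, this sketch works on
# bare lengths, does not pad a short final chunk, and does no shuffling.
import math
def sketch_token_budget_batches(lengths, chunk_size, batch_size_tokens, pad_to_multiple=64):
    def padded_tokens(batch):
        return max(math.ceil(n / pad_to_multiple) * pad_to_multiple for n in batch) * len(batch)
    budget = batch_size_tokens * chunk_size
    batches, current = [], []
    for i in range(0, len(lengths), chunk_size):
        chunk = lengths[i : i + chunk_size]
        if current and padded_tokens(current + chunk) > budget:
            batches.append(current)
            current = []
        # A single chunk is always accepted, even if it alone exceeds the budget.
        current.extend(chunk)
    if current:
        batches.append(current)
    return batches
def sketch_rank_share(global_batch, rank, num_replicas):
    # Strided slice, as in global_batch[self.rank : len(global_batch) : self.num_replicas].
    return global_batch[rank::num_replicas]
# sketch_token_budget_batches([10, 20, 30, 40, 100, 120, 200, 250], chunk_size=2, batch_size_tokens=128)
#   -> [[10, 20, 30, 40], [100, 120], [200, 250]]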
class PipelineDataLoader:
def __init__(
self,
dataset,
tokenizer,
batch_size,
gradient_accumulation_steps,
data_parallel_world_size,
data_parallel_rank,
shuffle=True,
group_by_length=False,
pad_to_multiple_of=PAD_TO_MULTIPLE,
batch_size_tokens=None,
return_dict=False,
rl=False,
):
assert data_parallel_rank < data_parallel_world_size
self.dataset = dataset
self.tokenizer = tokenizer
self.batch_size = batch_size
self.batch_size_tokens = batch_size_tokens
self.gradient_accumulation_steps = gradient_accumulation_steps
self.pad_to_multiple_of = pad_to_multiple_of
self.return_dict = return_dict
self.rl = rl
self.data_sampler = DistributedBatchSamper(
dataset=dataset,
batch_size=self.batch_size,
batch_size_tokens=self.batch_size_tokens,
batch_size_multiplier=self.gradient_accumulation_steps,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank,
shuffle=shuffle,
group_by_length=group_by_length,
)
data_collator = DataCollatorForSeq2Seq(self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
def collate_fn(examples, gradient_accumulation_steps=self.gradient_accumulation_steps):
has_batch_dimension = examples[0]['input_ids'].ndim == 2
if has_batch_dimension:
examples = flatten_examples(examples)
rejected_examples = []
for example in examples:
example.pop('length', None)
example.pop('token_type_ids', None)
rejected_example = {}
for key in list(example.keys()):
if key.startswith('rejected_'):
x = example.pop(key)
# Just drop the rejected_ entries if not doing RL. This allows normal SFT on just the
# accepted completions of an RL dataset.
if self.rl:
rejected_example[key.removeprefix('rejected_')] = x
if rejected_example:
rejected_examples.append(rejected_example)
if rejected_examples:
examples = combine_piecewise(examples, rejected_examples, gradient_accumulation_steps)
return data_collator(examples)
self.collate_fn = collate_fn
self.epoch = 1
self.num_batches_pulled = 0
self.next_micro_batch = None
self.recreate_dataloader = False
self._create_dataloader()
self.data = self._pull_batches_from_dataloader()
def reset(self):
self.epoch = 1
self.num_batches_pulled = 0
self.next_micro_batch = None
self.data = self._pull_batches_from_dataloader()
def __iter__(self):
return self
def __len__(self):
return len(self.data_sampler) * self.gradient_accumulation_steps
def __next__(self):
if self.next_micro_batch is None:
self.next_micro_batch = next(self.data)
ret = self.next_micro_batch
try:
self.next_micro_batch = next(self.data)
except StopIteration:
if self.recreate_dataloader:
self._create_dataloader()
self.recreate_dataloader = False
self.data = self._pull_batches_from_dataloader()
self.num_batches_pulled = 0
self.next_micro_batch = next(self.data)
self.epoch += 1
return ret
def _pull_batches_from_dataloader(self):
for batch in self.dataloader:
self.num_batches_pulled += 1
for micro_batch in split_batch(batch, self.gradient_accumulation_steps):
if self.return_dict:
yield micro_batch
else:
# input to pipeline is (input_ids, attention_mask, labels)
# this needs to return (features, labels)
# it is OK if labels is None (the model just returns the loss anyway)
yield example_to_tuple(micro_batch)
def _create_dataloader(self):
self.dataloader = DataLoader(
self.dataset,
pin_memory=True,
batch_sampler=self.data_sampler,
collate_fn=self.collate_fn,
)
def state_dict(self):
return {
'epoch': self.epoch,
'num_batches_pulled': self.num_batches_pulled,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
# -1 because by preloading the next micro_batch, it's always going to have one more batch
# pulled than the actual number of batches iterated by the caller.
self.num_batches_pulled = state_dict['num_batches_pulled'] - 1
self._create_dataloader()
self.dataloader = accelerate.skip_first_batches(self.dataloader, self.num_batches_pulled)
self.data = self._pull_batches_from_dataloader()
# Recreate the dataloader after the first pass so that it won't skip
# batches again (we only want it to skip batches the first time).
self.recreate_dataloader = True
# Only the first and last stages in the pipeline pull from the dataloader, but other parts of the
# code need the current epoch on every process, so all-gather it and take the max.
def sync_epoch(self):
process_group = dist.get_world_group()
result = [None] * dist.get_world_size(process_group)
torch.distributed.all_gather_object(result, self.epoch, group=process_group)
max_epoch = -1
for epoch in result:
max_epoch = max(epoch, max_epoch)
self.epoch = max_epoch
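# Illustrative sketch (hypothetical helper, not used by the dataloader): str.strip() removes any
# of the given *characters* from both ends of a string, not a literal prefix, so it can eat into
# the rest of a key. str.removeprefix() (Python 3.9+) removes exactly one leading 'rejected_'.
# The key 'rejected_ctx_ids' below is a made-up example.
def sketch_strip_vs_removeprefix(key):
    return key.strip('rejected_'), key.removeprefix('rejected_')
# sketch_strip_vs_removeprefix('rejected_input_ids') -> ('input_ids', 'input_ids')
# sketch_strip_vs_removeprefix('rejected_ctx_ids')   -> ('x_ids', 'ctx_ids')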
# for testing
if __name__ == '__main__':
tokenizer = transformers.AutoTokenizer.from_pretrained(sys.argv[1], local_files_only=True)
tokenizer.pad_token_id = 1000
from datasets import Dataset
data = []
for i in range(1, 41):
input_ids = torch.tensor([i] * i)
data.append(
{
'input_ids': input_ids,
'attention_mask': torch.ones_like(input_ids),
'labels': input_ids,
'length': len(input_ids),
}
)
dataset = Dataset.from_list(data)
# dataloader = PipelineDataLoader(dataset, tokenizer, batch_size=2, gradient_accumulation_steps=2, data_parallel_world_size=1, data_parallel_rank=0, group_by_length=True, pad_to_multiple_of=None)
# for batch in dataloader:
# if dataloader.epoch > 1:
# break
# print(batch)
# print()
batch_size = 2
gradient_accumulation_steps = 2
data_parallel_world_size = 2
data_parallel_rank = 0
dataloader = PipelineDataLoader(
dataset,
tokenizer,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank,
shuffle=False,
group_by_length=False,
pad_to_multiple_of=None,
)
print(next(dataloader)[0][0])
print(next(dataloader)[0][0])
print(next(dataloader)[0][0])
print(next(dataloader)[0][0])
state_dict = dataloader.state_dict()
dataloader = PipelineDataLoader(
dataset,
tokenizer,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
data_parallel_world_size=data_parallel_world_size,
data_parallel_rank=data_parallel_rank,
shuffle=False,
group_by_length=False,
pad_to_multiple_of=None,
)
dataloader.load_state_dict(state_dict)
print()
print('-' * 80)
print()
print(next(dataloader)[0][0])
print(next(dataloader)[0][0])
print(next(dataloader)[0][0])
print(next(dataloader)[0][0])
================================================
FILE: utils/dataset_utils.py
================================================
import os
import os.path
import sys
sys.path.insert(0, os.path.abspath('axolotl/src'))
import datasets
import torch
import yaml
from tqdm import tqdm
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from utils.utils import is_main_process, zero_first
# os.cpu_count() can return None, so fall back to a single process
NUM_PROC = min(64, os.cpu_count() or 1)
def yield_sequences_from_token_batch(tokenizer, token_batch, sequence_len):
# Initialize sequence_tokens with BOS token if it exists
sequence_tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
for tokens in tqdm(token_batch):
tokens = tokens.tolist()
assert len(tokens) > 0, 'Empty tokens list'
if tokens[-1] != tokenizer.eos_token_id:
tokens.append(tokenizer.eos_token_id)
idx = 0
# Skip the auto-generated BOS token if present
if tokenizer.bos_token_id is not None and tokens[0] == tokenizer.bos_token_id:
idx += 1
while idx < len(tokens):
# Calculate how many tokens are needed to fill the sequence
need = sequence_len - len(sequence_tokens)
taken = tokens[idx : idx + need]
idx += len(taken)
sequence_tokens.extend(taken)
if len(sequence_tokens) >= sequence_len:
assert len(sequence_tokens) == sequence_len
yield sequence_tokens
# Reset sequence_tokens with BOS token if it exists
sequence_tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
# yield anything remaining
# TODO: disabled until I get training working with variable length sequences
# if len(sequence_tokens) > 0:
# yield sequence_tokens
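# Illustrative sketch (hypothetical helper, not used by load_raw_dataset): the packing scheme of
# yield_sequences_from_token_batch() on plain lists, with made-up token ids (bos_id=1, eos_id=2).
# Each document gets an EOS appended (if missing) and its leading BOS dropped. Documents are then
# concatenated into fixed-length sequences that each start with BOS. The trailing partial
# sequence is dropped, matching the disabled code above.
def sketch_pack_documents(docs, sequence_len, bos_id=1, eos_id=2):
    start = [bos_id] if bos_id is not None else []
    assert sequence_len > len(start), 'sequence_len must leave room for at least one token'
    packed, seq = [], list(start)
    for doc in docs:
        tokens = list(doc)
        if tokens[-1] != eos_id:
            tokens.append(eos_id)
        idx = 1 if bos_id is not None and tokens[0] == bos_id else 0
        while idx < len(tokens):
            taken = tokens[idx : idx + sequence_len - len(seq)]
            idx += len(taken)
            seq.extend(taken)
            if len(seq) == sequence_len:
                packed.append(seq)
                seq = list(start)
    return packed
# sketch_pack_documents([[1, 10, 11, 12], [1, 20, 21]], sequence_len=4)
#   -> [[1, 10, 11, 12], [1, 2, 20, 21]]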
def slice_into_chunks(x, sequence_len, overlap=0):
result = []
step = sequence_len - overlap
for i in range(0, len(x), step):
result.append(x[i : i + sequence_len])
return result
def load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size, overlap=0, subsample_documents=None):
if dataset_path.endswith('.txt'):
dataset = datasets.load_dataset('text', data_files=dataset_path, sample_by='document')['train']
elif dataset_path.endswith('.json') or dataset_path.endswith('.jsonl'):
dataset = datasets.load_dataset('json', data_files=dataset_path)['train']
else:
raise NotImplementedError(f'Unsupported dataset file extension (expected .txt, .json or .jsonl): {dataset_path}')
dataset.set_format(type='torch')
if subsample_documents:
dataset = dataset.shuffle(seed=13).select(list(range(int(subsample_documents * len(dataset)))))
dataset = dataset.map(
lambda x: tokenizer(x['text']),
batched=True,
batch_size=10,
remove_columns=dataset.column_names,
desc='tokenizing',
num_proc=NUM_PROC,
)
# TODO: maybe do it this way instead
# dataset = dataset.map(lambda x: {'tokens': slice_into_chunks(x['tokens'][0], sequence_len, overlap=overlap)}, batched=True, batch_size=1)
dataset = dataset.map(
lambda x: {'input_ids': list(yield_sequences_from_token_batch(tokenizer, x['input_ids'], sequence_len))},
batched=True,
batch_size=None,
remove_columns=dataset.column_names,
desc='splitting',
)
dataset = dataset.map(
lambda x: {'attention_mask': torch.ones_like(x['input_ids']), 'labels': x['input_ids']},
desc='adding attention_mask and labels',
)
if eval_size > 0:
split_datasets = dataset.train_test_split(test_size=eval_size, shuffle=True, seed=42)
train_data = split_datasets['train']
eval_data = split_datasets['test']
else:
train_data = dataset
eval_data = None
return train_data, eval_data
def load_axolotl_dataset(dataset_path, tokenizer, sequence_len, eval_size):
with open(dataset_path) as f:
cfg = yaml.safe_load(f.read())
if 'val_set_size' not in cfg:
cfg['val_set_size'] = 0 if eval_size is None else eval_size
cfg['sequence_len'] = sequence_len
cfg['tokenizer_config'] = 'dummy'
# these don't matter, but they have to be set
cfg['batch_size'] = 1
cfg['num_epochs'] = 1
cfg['sequence_parallel_degree'] = 1
cfg = DictDefault(cfg)
train_data, eval_data, *_ = prepare_dataset(cfg, tokenizer)
train_data.set_format(type='torch')
if eval_data is not None:
eval_data.set_format(type='torch')
return train_data, eval_data
def load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eval_size):
ds = datasets.load_from_disk(dataset_path)
assert 'input_ids' in ds.column_names
assert 'attention_mask' in ds.column_names
assert 'labels' in ds.column_names
ds = ds.filter(lambda example: len(example['input_ids']) <= sequence_len, desc='dropping long sequences', num_proc=NUM_PROC)
if eval_size > 0:
split_datasets = ds.train_test_split(test_size=eval_size, shuffle=True, seed=42)
train_data = split_datasets['train']
eval_data = split_datasets['test']
else:
train_data = ds
eval_data = None
return train_data, eval_data
def load_single_dataset(dataset_config, tokenizer):
dataset_path = dataset_config['dataset_path']
dataset_type = dataset_config['dataset_type']
sequence_len = dataset_config['sequence_len']
eval_size = dataset_config.get('eval_size', 0)
subsample = dataset_config.get('subsample', None)
num_repeats = dataset_config.get('num_repeats', None)
if dataset_type in ['textfile', 'doclist']:
with zero_first(is_main_process()):
train_data, eval_data = load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size)
elif dataset_type == 'axolotl':
train_data, eval_data = load_axolotl_dataset(dataset_path, tokenizer, sequence_len, eval_size)
elif dataset_type == 'pretokenized':
with zero_first(is_main_process()):
train_data, eval_data = load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eval_size)
else:
raise NotImplementedError(f'Unknown dataset_type: {dataset_type}')
train_data = train_data.shuffle(seed=42)
if eval_data is not None:
eval_data = eval_data.shuffle(seed=42)
if subsample is not None:
assert 0 < subsample < 1
train_data = train_data.select(range(int(len(train_data) * subsample)))
if eval_data is not None:
eval_data = eval_data.select(range(int(len(eval_data) * subsample)))
def add_length(x):
length = len(x['input_ids'])
if 'rejected_input_ids' in x:
length = max(length, len(x['rejected_input_ids']))
return {'length': length}
with zero_first(is_main_process()):
train_data = train_data.map(add_length, desc='adding length field', num_proc=NUM_PROC)
if eval_data is not None:
eval_data = eval_data.map(add_length, desc='adding length field', num_proc=NUM_PROC)
if 'prompt_attention_mask' in train_data.column_names:
train_data = train_data.remove_columns('prompt_attention_mask')
if eval_data is not None:
eval_data = eval_data.remove_columns('prompt_attention_mask')
if num_repeats:
train_data = train_data.repeat(num_repeats)
if is_main_process():
print(f'train_data size: {len(train_data)}')
if eval_data is not None:
print(f'eval_data size: {len(eval_data)}')
return train_data, eval_data
def combine_datasets(dataset_list, config, sample_weights):
sample_weights = torch.tensor(sample_weights, dtype=torch.float32)
mode = config.get('dataset_combination_mode', 'concatenate')
if mode == 'concatenate':
dataset = datasets.concatenate_datasets(dataset_list)
elif mode == 'interleave':
if 'batch_size_tokens' in config:
# batches are formed so they have equal token counts, so interleave datasets based on token counts, not rows
avg_lengths = torch.tensor(
[dataset['length'].to(torch.float32).mean() for dataset in dataset_list], dtype=torch.float32
)
sample_weights = sample_weights / avg_lengths
# float64, or interleave_datasets complains that probs don't sum to 1
sample_weights = sample_weights.to(torch.float64)
probs = sample_weights / sample_weights.sum()
dataset = datasets.interleave_datasets(
dataset_list,
probabilities=probs,
seed=42,
stopping_strategy=config.get('dataset_interleave_stopping_strategy', 'first_exhausted'),
)
else:
raise ValueError(mode)
return dataset
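# Illustrative sketch (hypothetical helper, not used by combine_datasets): the probability
# computation for 'interleave' mode on plain floats. When batch_size_tokens is set, each weight
# is divided by the dataset's mean row length, so the weights control each dataset's share of
# tokens rather than rows.
def sketch_interleave_probs(sample_weights, avg_lengths=None):
    weights = [float(w) for w in sample_weights]
    if avg_lengths is not None:
        weights = [w / length for w, length in zip(weights, avg_lengths)]
    total = sum(weights)
    return [w / total for w in weights]
# Equal weights, rows of ~100 vs ~400 tokens: rows are drawn roughly 80/20, which gives each
# dataset the same expected tokens per draw (0.8 * 100 == 0.2 * 400).
#   sketch_interleave_probs([1.0, 1.0], avg_lengths=[100, 400]) -> approximately [0.8, 0.2]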
# TODO: reduce the extra unneeded left padding caused by accepted and rejected being collated
# together.
def process_dataset_for_rejected_sampling(dataset):
def _rejected_sampling_map_fn(example):
input_ids = example['input_ids']
attention_mask = example['attention_mask']
labels = example['labels']
assert input_ids.ndim == 1
assert attention_mask.ndim == 1
assert labels.ndim == 1
# Labels are -100 for masked tokens (the prompt).
label_mask = (labels == -100).to(torch.int32)
if (label_mask == 0).all():
# Raw text dataset
# TODO: allow configuring this
completion_start = len(label_mask) // 2
labels[:completion_start] = -100
else:
# index of the first unmasked (completion) token
completion_start = torch.argmin(label_mask).item()
return {
'labels': labels,
'rejected_input_ids': input_ids[:completion_start],
'rejected_attention_mask': attention_mask[:completion_start],
'rejected_labels': labels[:completion_start],
}
return dataset.map(_rejected_sampling_map_fn, desc='Processing dataset for negative sampling', num_proc=NUM_PROC)
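# Illustrative sketch (hypothetical helper, not used by the dataset pipeline): how
# process_dataset_for_rejected_sampling() picks the point where the prompt ends, written on
# plain lists instead of tensors. Labels of -100 mark prompt tokens. When no token is masked
# (raw text), the first half of the row is treated as the prompt.
def sketch_completion_start(labels):
    masked = [1 if label == -100 else 0 for label in labels]
    if all(m == 0 for m in masked):
        return len(labels) // 2
    # torch.argmin returns the first minimal index: the first unmasked token, or 0 if every
    # token is masked.
    return masked.index(0) if 0 in masked else 0
# sketch_completion_start([-100, -100, -100, 5, 6]) -> 3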
def load_datasets(config, tokenizer):
if 'datasets' not in config:
raise ValueError('Need to specify at least one dataset')
train_datasets = []
sample_weights = []
eval_datasets = {}
i = 0
for dataset_config in config['datasets']:
if 'name' in dataset_config:
name = dataset_config['name']
else:
name = f'dataset{i}'
i += 1
sample_weights.append(dataset_config.get('sample_weight', 1.0))
train, eval = load_single_dataset(dataset_config, tokenizer)
train_datasets.append(train)
if eval is not None:
eval_datasets[name] = eval
for dataset_config in config.get('eval_datasets', []):
if 'name' in dataset_config:
name = dataset_config['name']
else:
name = f'dataset{i}'
i += 1
eval, _ = load_single_dataset(dataset_config, tokenizer)
eval_datasets[name] = eval
if len(train_datasets) == 1:
train_dataset = train_datasets[0]
else:
with zero_first(is_main_process()):
train_dataset = combine_datasets(train_datasets, config, sample_weights=sample_weights)
if config.get('rejected_sampling', False):
assert 'rl' in config
train_dataset = process_dataset_for_rejected_sampling(train_dataset)
eval_datasets = {name: process_dataset_for_rejected_sampling(ds) for name, ds in eval_datasets.items()}
if 'rl' in config:
assert 'rejected_input_ids' in train_dataset.column_names
for eval_dataset in eval_datasets.values():
assert 'rejected_input_ids' in eval_dataset.column_names
train_dataset.set_format(type='torch')
for eval_dataset in eval_datasets.values():
eval_dataset.set_format(type='torch')
return train_dataset, eval_datasets
# for testing
if __name__ == '__main__':
import transformers
# from datasets import disable_caching
# disable_caching()
tokenizer = transformers.AutoTokenizer.from_pretrained(
sys.argv[1], local_files_only=True, use_fast=False, legacy=True
)
tokenizer.pad_token_id = 0
tokenizer.padding_side = 'right'
train_data1, eval_data1 = load_raw_dataset('/home/anon/data/test/txt/*.txt', tokenizer, 100, 0.5)
train_data2, eval_data2 = load_raw_dataset('/home/anon/data/test/json/*.jsonl', tokenizer, 100, 0.5)
print(len(train_data1))
print(len(train_data2))
================================================
FILE: utils/engine.py
================================================
from collections import deque
import time
import deepspeed
import torch
import transformers
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime import utils as ds_utils
from deepspeed.runtime.activation_checkpointing import checkpointing as ds_checkpointing
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.runtime.pipe import p2p, schedule
from deepspeed.runtime.pipe.engine import (
BATCH_INPUT_TIMER,
PIPE_RECV_GRAD_TIMER,
PIPE_RECV_INPUT_TIMER,
PIPE_SEND_GRAD_TIMER,
PIPE_SEND_OUTPUT_TIMER,
TRAIN_BATCH_TIMER,
PipelineEngine,
)
from deepspeed.runtime.pipe.module import LayerSpec, PipelineModule
from deepspeed.runtime.pipe.schedule import (
BackwardPass,
BufferOpInstruction,
ForwardPass,
OptimizerStep,
PipeInstruction,
PipeSchedule,
RecvActivation,
RecvGrad,
ReduceGrads,
ReduceTiedGrads,
SendActivation,
SendGrad,
_is_even,
_is_odd,
)
from deepspeed.runtime.pipe.topology import ProcessTopology
from deepspeed.runtime.utils import PartitionedTensor
from torch import nn
from utils.utils import eta_str, log, is_main_process
from utils.dataloader import split_batch, example_to_tuple
def initialize(args=None, model=None, config=None, **kwargs):
assert model is not None, 'deepspeed.initialize requires a model'
dist_backend = get_accelerator().communication_backend_name()
dist.init_distributed(dist_backend=dist_backend)
if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
config = args.deepspeed_config
mpu = model.mpu()
config_class = DeepSpeedConfig(config, mpu)
engine = CustomPipelineEngine(
args=args,
model=model,
mpu=mpu,
config=config,
config_class=config_class,
**kwargs,
)
return engine, engine.optimizer
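# Undoes the layout built by combine_piecewise() and split_batch(): within each micro-batch the first
# half holds the accepted examples and the second half the rejected ones, which are moved under
# 'rejected_'-prefixed keys.
def unpack_accepted_rejected(example):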
batch_size = example['input_ids'].size(0)
half = batch_size // 2
for key, tensor in list(example.items()):
assert tensor.size(0) == batch_size
example[key] = tensor[:half]
example['rejected_' + key] = tensor[half:]
return example
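# Illustrative sketch (hypothetical helpers, not the transformers implementation): what the
# sampling_temperature_last flag of CustomPipelineEngine below changes. Min-p keeps tokens whose
# probability is at least min_p times the top probability. Applying temperature first (the
# default) filters on the temperature-scaled distribution. Applying it last filters on the
# unscaled distribution and then rescales only the surviving tokens.
import math
def sketch_softmax(logits):
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]
def sketch_min_p_sample_probs(logits, min_p, temperature, temperature_last=False):
    if not temperature_last:
        logits = [x / temperature for x in logits]
    probs = sketch_softmax(logits)
    threshold = min_p * max(probs)
    # Filtered tokens get -inf logits, i.e. zero probability after the final softmax.
    kept = [x if p >= threshold else -math.inf for x, p in zip(logits, probs)]
    if temperature_last:
        kept = [x / temperature for x in kept]
    return sketch_softmax(kept)
# With logits [2.0, 1.0, 0.0], min_p=0.3 and temperature=2.0, temperature-first keeps all three
# tokens. Temperature-last drops the third, because its unscaled probability (~0.09) is below
# 0.3 * ~0.67.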
class LoadMicroBatchMultipleBuffers(PipeInstruction):
def __init__(self, *buffer_ids, **kwargs):
super().__init__(buffer_ids=buffer_ids, **kwargs)
class ReferenceLogitsForwardPass(BufferOpInstruction):
pass
class CustomPipelineEngine(PipelineEngine):
def __init__(
self,
*args,
lora_model=None,
tokenizer=None,
rl_config=None,
rejected_sampling=False,
rejected_sampling_max_new_tokens=1e9,
sampling_temperature=1.0,
sampling_min_p=0.,
sampling_temperature_last=False,
**kwargs
):
super().__init__(*args, **kwargs)
self.total_steps = None
self.etas = deque()
# Assign list to avoid registering the nn.Module
self.lora_model = [lora_model]
self.tokenizer = tokenizer
self.rl_config = rl_config
self.rejected_sampling = rejected_sampling
self.rejected_sampling_max_new_tokens = rejected_sampling_max_new_tokens
eos_token_ids = set()
if self.tokenizer is not None and self.tokenizer.eos_token_id is not None:
eos_token_ids.add(self.tokenizer.eos_token_id)
model_config = self.module.model.config
if model_config.eos_token_id is not None:
model_eos_token_ids = model_config.eos_token_id
if isinstance(model_eos_token_ids, int):
model_eos_token_ids = [model_eos_token_ids]
eos_token_ids.update(model_eos_token_ids)
self.eos_token_ids = eos_token_ids
# Sampling configuration. Only supports logits processors that don't use input_ids.
self.logits_processor = transformers.LogitsProcessorList()
temp = transformers.TemperatureLogitsWarper(float(sampling_temperature))
if sampling_min_p > 0:
self.logits_processor.append(transformers.MinPLogitsWarper(float(sampling_min_p)))
if sampling_temperature_last:
self.logits_processor.append(temp)
else:
self.logits_processor.insert(0, temp)
def set_dataloader(self, loader):
self.collate_fn = loader.collate_fn
if self.is_first_stage() or self.is_last_stage():
self.training_dataloader = loader
self.data_iterator = iter(self.training_dataloader)
def train_batch(self):
if not torch._C.is_grad_enabled():
raise RuntimeError('train_batch() requires gradients enabled. Use eval_batch() instead.')
self.timers(TRAIN_BATCH_TIMER).start()
# sequence length may change between macro batches (but not between gradient accumulation steps)
self.reset_activation_shape()
train_iterator = self.data_iterator
# Negative sampling
if self.rejected_sampling:
assert self.rl_config
self.module.eval()
model_inputs = self._sample_from_iterator(train_iterator, self.collate_fn)
dist.barrier()
if model_inputs is not None:
self.set_dataiterator(iter(model_inputs))
self.reset_activation_shape()
self.module.train()
self._compute_loss = True
# Do the work
if self.rl_config:
method = self.rl_config.get('method', None)
if method == 'dpo':
sched = DPOTrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id)
else:
raise NotImplementedError(method)
else:
sched = schedule.TrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id)
self._exec_schedule(sched)
agg_losses = self._aggregate_total_losses()
# Actual training loss is always the first item.
self.agg_train_loss = agg_losses[0].mean()
self.timers(TRAIN_BATCH_TIMER).stop()
if self.global_steps % self.steps_per_print() == 0:
if self.global_rank == 0:
elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0
iter_time = elapsed / self.steps_per_print()
eta = iter_time * (self.total_steps - self.global_steps)
self.etas.append(eta)
while len(self.etas) > 10:
self.etas.popleft()
rolling_eta = sum(self.etas) / len(self.etas)
tput = self.train_batch_size() / iter_time
log(
f'step: {self.global_steps:>5} / {self.total_steps:>5} '
f'loss: {self.agg_train_loss:0.4f} '
f'iter time (s): {iter_time:0.3f} '
f'samples/sec: {tput:0.3f} '
f'eta: {eta_str(rolling_eta)} '
)
else:
self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True)
# Monitoring
if self.global_rank == 0 and self.monitor.enabled:
self.summary_events = [
('Train/Samples/train_loss', self.agg_train_loss.mean().item(), self.global_samples)
]
self.monitor.write_events(self.summary_events)
if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0:
self.timers.log(
[
PIPE_SEND_OUTPUT_TIMER,
PIPE_SEND_GRAD_TIMER,
PIPE_RECV_INPUT_TIMER,
PIPE_RECV_GRAD_TIMER,
]
)
# Restore the training iterator
self.set_dataiterator(train_iterator)
return agg_losses
def eval_batch(self, data_iter):
# sequence length may change between macro batches (but not between gradient accumulation steps)
self.reset_activation_shape()
self.module.eval()
self._compute_loss = True
# Use the provided data iterator
train_iterator = self.data_iterator
self.set_dataiterator(data_iter)
# Negative sampling
if self.rejected_sampling:
assert self.rl_config
dist.barrier()
model_inputs = self._sample_from_iterator(data_iter, self.collate_fn)
if model_inputs is not None:
self.set_dataiterator(iter(model_inputs))
self.reset_activation_shape()
# Do the work
if self.rl_config:
method = self.rl_config.get('method', None)
if method == 'dpo':
sched = DPOInferenceSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id)
else:
raise NotImplementedError(method)
else:
sched = schedule.InferenceSchedule(
micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id
)
# prevent deadlock when running multiple evals in sequence
dist.barrier()
with torch.no_grad():
self._exec_schedule(sched)
# list of losses
agg_eval_losses = self._aggregate_total_losses()
if self.global_rank == 0 and self.monitor.enabled:
self.summary_events = [('Train/Samples/eval_loss', agg_eval_losses[0].mean().item(), self.global_samples)]
self.monitor.write_events(self.summary_events)
# Restore the training iterator
self.set_dataiterator(train_iterator)
return agg_eval_losses

    def sample_batch(self, prompts):
        assert isinstance(prompts, (list, tuple))
        self.reset_activation_shape()
        self.module.eval()
        self.module.set_sampling_mode(True)
        # Each element of `prompts` becomes one micro-batch; an element may itself be a list of prompts.
        original_micro_batches = self.micro_batches
        self.micro_batches = len(prompts)
        dist.barrier()

        if self.is_first_stage():
            # Tokenizer returns dict with 'input_ids', 'attention_mask' keys.
            # Tensors have batch dimension because we pass list of prompts.
            examples = []
            for prompt in prompts:
                if not isinstance(prompt, (list, tuple)):
                    prompt = [prompt]
                examples.append(self.tokenizer(prompt, return_tensors='pt', padding=True))
        else:
            examples = None

        with torch.no_grad():
            examples = self._exec_sampling_schedule(examples)

        if self.is_first_stage():
            text = [self.tokenizer.batch_decode(example['input_ids']) for example in examples]
        else:
            text = None

        self.micro_batches = original_micro_batches
        self.module.set_sampling_mode(False)
        return text

    def _sample_from_iterator(self, data_iter, collate_fn):
        if data_iter is not None:
            examples = [unpack_accepted_rejected(next(data_iter)) for _ in range(self.micro_batches)]
            # TODO: allow configuring this
            max_total_tokens = max(example['input_ids'].size(1) for example in examples) * 2
        else:
            examples = None
            # This is okay, max_total_tokens is only checked on the first stage.
            max_total_tokens = None

        self.module.eval()
        self.module.set_sampling_mode(True)
        dist.barrier()
        with torch.no_grad():
            examples = self._exec_sampling_schedule(examples, feature_prefix='rejected_', max_total_tokens=max_total_tokens)

        if is_main_process():
            input_ids = examples[0]['rejected_input_ids'][0]
            attention_mask = examples[0]['rejected_attention_mask'][0]
            # torch.argmax returns the index of the first maximal value, so on a 0/1 mask these give
            # the first non-padding position and one past the last non-padding position.
            start = torch.argmax(attention_mask)
            end = len(attention_mask) - torch.argmax(torch.flip(attention_mask, (0,)))
            text = self.tokenizer.decode(input_ids[start:end])
            print(f'Example of sampled rejected completion:\n{text}')

        self.module.set_sampling_mode(False)
        if examples is not None:
            batch = collate_fn(examples, gradient_accumulation_steps=self.micro_batches)
            model_inputs = [example_to_tuple(micro_batch) for micro_batch in split_batch(batch, self.micro_batches)]
            return model_inputs
        else:
            return None

    def _aggregate_total_losses(self):
        all_agg_outputs = []
        # gather each output for all the gradient accumulation steps
        grouped_outputs = [list(x) for x in zip(*self.fwd_outputs)]
        # if any are scalar, make them dim 1 so we can concat across DP ranks
        for outputs in grouped_outputs:
            for i, output in enumerate(outputs):
                if output.dim() == 0:
                    outputs[i] = torch.unsqueeze(output, 0)

        if self.is_last_stage():
            agg_sizes = []
            # loop to gather all the outputs across DP ranks
            for outputs in grouped_outputs:
                # concat all the grad_accum_steps
                concat_outputs = torch.cat(outputs)
                if self.is_data_parallel:
                    # might be different sizes across DP ranks, so, gather all the sizes
                    sizes = [None] * self.grid.get_data_parallel_world_size()
                    torch.distributed.all_gather_object(
                        sizes, concat_outputs.size(), group=self.grid.get_data_parallel_group()
                    )
                    # once we know all the sizes we can gather the results across DP ranks
                    gather_result = [torch.zeros(size).to(self.device) for size in sizes]
                    dist.all_gather(gather_result, concat_outputs, group=self.grid.get_data_parallel_group())
                    # and finally, concat
                    agg_output = torch.cat(gather_result)
                else:
                    agg_output = concat_outputs
                agg_sizes.append(agg_output.size())
                all_agg_outputs.append(agg_output)

            # send the sizes, then broadcast to the PP ranks
            if self.is_pipe_parallel:
                torch.distributed.broadcast_object_list(
                    [agg_sizes], src=self.global_rank, group=self.grid.get_pipe_parallel_group()
                )
                for agg_output in all_agg_outputs:
                    dist.broadcast(tensor=agg_output, src=self.global_rank, group=self.grid.get_pipe_parallel_group())
        else:
            # get the outputs from the last stage
            src_rank = self.grid.stage_to_global(self.num_stages - 1)
            assert src_rank in self.grid.pp_group
            result = [None]
            torch.distributed.broadcast_object_list(result, src=src_rank, group=self.grid.get_pipe_parallel_group())
            agg_sizes = result[0]
            for agg_size in agg_sizes:
                agg_output = torch.zeros(agg_size).to(self.device)
                dist.broadcast(tensor=agg_output, src=src_rank, group=self.grid.get_pipe_parallel_group())
                all_agg_outputs.append(agg_output)

        return all_agg_outputs

    # We override this to handle the model returning a list of "losses", but only doing backprop on the first.
    def _exec_forward_pass(self, buffer_id):
        self.tput_timer.start()
        self.mem_status('BEFORE FWD', reset_max=True)

        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
        else:
            inputs = self.pipe_buffers['inputs'][buffer_id].clone()

        # collect the partitioned input from the previous stage
        if self.is_pipe_partitioned and not self.is_first_stage():
            part_input = PartitionedTensor.from_meta(
                meta=inputs[0], local_part=inputs[1], group=self.grid.get_slice_parallel_group()
            )

            inputs = (part_input.full(), *inputs[2:])
            inputs[0].requires_grad = True
            # skip mask
            # inputs[1].requires_grad = True
            part_input = None
            inputs = inputs[0] if len(inputs) == 1 else inputs
            self.pipe_buffers['inputs'][buffer_id] = inputs

        # inputs has no gradient because it is from a cloned tensor
        outputs = super(PipelineEngine, self).forward(inputs)

        # Reset activation checkpointing buffers.
        # Need to call this between evaluation iterations
        if not self.module.training:
            ds_checkpointing.reset()

        # Partition the outputs if we are not the last stage
        if self.is_pipe_partitioned and not self.is_last_stage():
            if isinstance(outputs, tuple):
                first_output = outputs[0]
                # TODO: Improve pipe partitioning to pass multiple tensors that require grads
                assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:])
                outputs_tail = outputs[1:]
            elif torch.is_tensor(outputs):
                first_output = outputs
                outputs_tail = []
            else:
                raise ValueError('expecting a tensor or a tuple of tensors')
            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())
            # Clear the large output data, but save the computation graph.
            # Keep the placeholder on the same device as the original output.
            first_output.data = torch.zeros(1, device=first_output.data.device)
            self.pipe_buffers['output_tensors'][buffer_id] = first_output
            # Inject the partitioned tensor into the output before sending
            outputs = (part.to_meta(), part.data(), *outputs_tail)
            part = None

        self.pipe_buffers['outputs'][buffer_id] = outputs

        # Optionally compute loss on the last device
        if self.is_last_stage():
            if self._compute_loss and self.module.loss_fn is not None:
                labels = self.pipe_buffers['labels'][buffer_id]
                losses = self.module.loss_fn(outputs, labels)
            else:
                # Some models just return loss from forward()
                losses = outputs
            if self.eval_return_logits:
                self.outputs = outputs
            if isinstance(losses, torch.Tensor):
                self.loss = losses
                self.fwd_outputs.append([self.loss.detach()])
            else:
                # Backprop uses only the first loss; all of them are kept for logging.
                self.loss = losses[0]
                self.fwd_outputs.append([l.detach() for l in losses])

    def _exec_load_micro_batch_multiple_buffers(self, buffer_ids):
        if self.wall_clock_breakdown():
            self.timers(BATCH_INPUT_TIMER).start()

        batch = self._next_batch()

        if self.is_first_stage():
            loaded = None
            if torch.is_tensor(batch[0]):
                loaded = batch[0].clone().to(self.device).detach()
                if (
                    self._config.pipeline['activation_checkpoint_interval'] > 0
                    and self._config.pipeline['use_reentrant']
                ):
                    loaded.requires_grad = loaded.is_floating_point()
            else:
                assert isinstance(batch[0], (tuple, list))
                # Assume list or tuple
                loaded = []
                for x in batch[0]:
                    assert torch.is_tensor(x)
                    mine = x.clone().detach().to(self.device)
                    if (
                        self._config.pipeline['activation_checkpoint_interval'] > 0
                        and self._config.pipeline['use_reentrant']
                    ):
                        mine.requires_grad = mine.is_floating_point()
                    loaded.append(mine)
                loaded = tuple(loaded)

            for buffer_id in buffer_ids:
                self.pipe_buffers['inputs'][buffer_id] = loaded

        if self.is_last_stage():
            loaded = batch[1]
            if torch.is_tensor(batch[1]):
                loaded = batch[1].to(self.device)
            # XXX: torch 1.6.0 DataLoader will auto convert tuple to list
            elif isinstance(batch[1], (tuple, list)):
                loaded = []
                for x in batch[1]:
                    assert torch.is_tensor(x)
                    x = x.to(self.device).detach()
                    loaded.append(x)
                loaded = tuple(loaded)

            for buffer_id in buffer_ids:
                self.pipe_buffers['labels'][buffer_id] = loaded

        if self.wall_clock_breakdown():
            self.timers(BATCH_INPUT_TIMER).stop()

    @torch.no_grad()
    def _exec_reference_logits_forward_pass(self, buffer_id):
        self.lora_model[0].disable_adapter_layers()
        self.module.set_dpo_reference_mode(True)

        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
        else:
            inputs = self.pipe_buffers['inputs'][buffer_id].clone()

        # collect the partitioned input from the previous stage
        if self.is_pipe_partitioned and not self.is_first_stage():
            if self.pipe_partition_input_meta_cache is None:
                self.pipe_partition_input_meta_cache = inputs[0].to('cpu')
            part_input = PartitionedTensor.from_meta(
                meta=self.pipe_partition_input_meta_cache,
                local_part=inputs[1],
                group=self.grid.get_slice_parallel_group(),
            )

            inputs = (part_input.full(), *inputs[2:])
            inputs[0].requires_grad = True
            # skip mask
            # inputs[1].requires_grad = True
            part_input = None
            inputs = inputs[0] if len(inputs) == 1 else inputs
            self.pipe_buffers['inputs'][buffer_id] = inputs

        # inputs has no gradient because it is from a cloned tensor
        outputs = super(PipelineEngine, self).forward(inputs)

        # Reset activation checkpointing buffers.
        # Need to call this between evaluation iterations
        if not self.module.training:
            ds_checkpointing.reset()

        # Partition the outputs if we are not the last stage
        if self.is_pipe_partitioned and not self.is_last_stage():
            if isinstance(outputs, tuple):
                first_output = outputs[0]
                # TODO: Improve pipe partitioning to pass multiple tensors that require grads
                assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:])
                outputs_tail = outputs[1:]
            elif torch.is_tensor(outputs):
                first_output = outputs
                outputs_tail = []
            else:
                raise ValueError('expecting a tensor or a tuple of tensors')
            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())
            # Clear the large output data, but save the computation graph
            first_output.data = torch.zeros(1, device=first_output.data.device)
            self.pipe_buffers['output_tensors'][buffer_id] = first_output
            # Inject the partitioned tensor into the output before sending
            outputs = (part.to_meta(), part.data(), *outputs_tail)
            part = None

        self.pipe_buffers['outputs'][buffer_id] = outputs

        self.lora_model[0].enable_adapter_layers()
        self.module.set_dpo_reference_mode(False)

    def _exec_send_micro_batch_id(self, send_micro_batch_id):
        assert isinstance(send_micro_batch_id, int)
        if self.num_stages == 1:
            return send_micro_batch_id
        send_micro_batch_id = torch.tensor(send_micro_batch_id, device=self.device)
        recv_micro_batch_id = torch.tensor(-1, device=self.device)
        # Even stages send before receiving; odd stages receive before sending (the same ordering
        # DeepSpeed's InferenceSchedule uses), so each blocking send meets a waiting recv on the next stage.
        if _is_even(self.stage_id):
            if not self.is_last_stage():
                p2p.send(send_micro_batch_id, self.next_stage)
            if not self.is_first_stage():
                p2p.recv(recv_micro_batch_id, self.prev_stage)
        else:
            if not self.is_first_stage():
                p2p.recv(recv_micro_batch_id, self.prev_stage)
            if not self.is_last_stage():
                p2p.send(send_micro_batch_id, self.next_stage)
        # last s
SYMBOL INDEX (222 symbols across 17 files)
FILE: kernels/cross_entropy_loss.py
function _cross_entropy_forward (line 28) | def _cross_entropy_forward(
function _chunked_cross_entropy_forward (line 99) | def _chunked_cross_entropy_forward(
function _cross_entropy_backward (line 180) | def _cross_entropy_backward(
class Fast_CrossEntropyLoss (line 245) | class Fast_CrossEntropyLoss(torch.autograd.Function):
method forward (line 247) | def forward(ctx, logits, labels, logit_scale=1.0):
method backward (line 314) | def backward(ctx, dlosses):
function fast_cross_entropy_loss (line 352) | def fast_cross_entropy_loss(logits, labels, logit_scale=1.0):
FILE: kernels/utils.py
function device_warp_size (line 26) | def device_warp_size():
function calculate_settings (line 33) | def calculate_settings(n):
function QUANT_STATE (line 62) | def QUANT_STATE(W):
function get_lora_parameters (line 69) | def get_lora_parameters(proj):
function fast_dequantize (line 88) | def fast_dequantize(W, quant_state=None, out=None):
function fast_gemv (line 144) | def fast_gemv(X, W, quant_state, out=None):
function fast_linear_forward (line 222) | def fast_linear_forward(proj, X, temp_lora=None, out=None):
function matmul_lora (line 265) | def matmul_lora(X, W, W_quant, A, B, s, out=None):
FILE: models/layers.py
function move_data_to_device (line 12) | def move_data_to_device(module, device):
function set_data (line 27) | def set_data(module, data):
function move_experts_to_device (line 38) | def move_experts_to_device(experts, device, num_experts_to_offload):
function set_experts_data (line 48) | def set_experts_data(experts, orig_data):
function entropy_fn (line 55) | def entropy_fn(logits):
function top_k_accuracy (line 65) | def top_k_accuracy(logits, labels, k_list, ignore_index=-100):
class LayerSpec (line 77) | class LayerSpec(ds_pipe_module.LayerSpec):
method __init__ (line 78) | def __init__(self, typename, *module_args, **module_kwargs):
method build (line 81) | def build(self):
method estimated_size (line 86) | def estimated_size(self):
class OutputLayer (line 92) | class OutputLayer(nn.Module):
method __init__ (line 93) | def __init__(
method forward (line 121) | def forward(self, inputs):
function load_balancing_loss_func (line 238) | def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: tor...
class MixtralOutputLayer (line 255) | class MixtralOutputLayer(OutputLayer):
method __init__ (line 256) | def __init__(
method forward (line 271) | def forward(self, inputs):
class InputLayer (line 288) | class InputLayer(nn.Module):
method __init__ (line 289) | def __init__(self, model):
method model (line 298) | def model(self):
method forward (line 301) | def forward(self, inputs):
class LlamaRMSNormPipe (line 351) | class LlamaRMSNormPipe(nn.Module):
method __init__ (line 352) | def __init__(self, loader_util, orig):
method forward (line 357) | def forward(self, inputs):
class LlamaDecoderLayerPipe (line 362) | class LlamaDecoderLayerPipe(nn.Module):
method __init__ (line 363) | def __init__(self, pipeline_model, loader_util, orig):
method forward (line 377) | def forward(self, inputs):
method move_mlp_to_cpu (line 405) | def move_mlp_to_cpu(self):
method move_mlp_to_device (line 418) | def move_mlp_to_device(self, device):
class Phi3DecoderLayerPipe (line 424) | class Phi3DecoderLayerPipe(LlamaDecoderLayerPipe):
method __init__ (line 425) | def __init__(self, *args, **kwargs):
method move_mlp_to_cpu (line 428) | def move_mlp_to_cpu(self):
method move_mlp_to_device (line 438) | def move_mlp_to_device(self, device):
class MixtralDecoderLayerPipe (line 443) | class MixtralDecoderLayerPipe(LlamaDecoderLayerPipe):
method __init__ (line 444) | def __init__(self, *args, **kwargs):
method forward (line 448) | def forward(self, inputs):
method move_mlp_to_cpu (line 478) | def move_mlp_to_cpu(self):
method move_mlp_to_device (line 486) | def move_mlp_to_device(self, device):
class Gemma3InputLayer (line 492) | class Gemma3InputLayer(nn.Module):
method __init__ (line 493) | def __init__(self, model):
method model (line 503) | def model(self):
method forward (line 506) | def forward(self, inputs):
class Gemma3DecoderLayerPipe (line 557) | class Gemma3DecoderLayerPipe(nn.Module):
method __init__ (line 558) | def __init__(self, pipeline_model, loader_util, orig):
method forward (line 566) | def forward(self, inputs):
method move_mlp_to_cpu (line 602) | def move_mlp_to_cpu(self):
method move_mlp_to_device (line 615) | def move_mlp_to_device(self, device):
class Gemma3RMSNormPipe (line 621) | class Gemma3RMSNormPipe(nn.Module):
method __init__ (line 622) | def __init__(self, loader_util, orig):
method forward (line 627) | def forward(self, inputs):
FILE: models/models.py
class LlamaForCausalLMPipe (line 27) | class LlamaForCausalLMPipe(PipelineModel, transformers.LlamaForCausalLM):
method __init__ (line 28) | def __init__(self, config, quantization_config):
method to_layer_specs (line 37) | def to_layer_specs(self):
class Qwen2ForCausalLMPipe (line 59) | class Qwen2ForCausalLMPipe(PipelineModel, transformers.Qwen2ForCausalLM):
method __init__ (line 60) | def __init__(self, config, quantization_config):
method to_layer_specs (line 69) | def to_layer_specs(self):
class CohereForCausalLMPipe (line 88) | class CohereForCausalLMPipe(PipelineModel, transformers.CohereForCausalLM):
method __init__ (line 89) | def __init__(self, config, quantization_config):
method to_layer_specs (line 98) | def to_layer_specs(self):
class Phi3ForCausalLMPipe (line 122) | class Phi3ForCausalLMPipe(PipelineModel, transformers.Phi3ForCausalLM):
method __init__ (line 123) | def __init__(self, config, quantization_config):
method to_layer_specs (line 132) | def to_layer_specs(self):
class Gemma2ForCausalLMPipe (line 150) | class Gemma2ForCausalLMPipe(PipelineModel, transformers.Gemma2ForCausalLM):
method __init__ (line 151) | def __init__(self, config, quantization_config):
method to_layer_specs (line 160) | def to_layer_specs(self):
class MistralForCausalLMPipe (line 185) | class MistralForCausalLMPipe(PipelineModel, transformers.MistralForCausa...
method __init__ (line 186) | def __init__(self, config, quantization_config):
method to_layer_specs (line 195) | def to_layer_specs(self):
class MixtralForCausalLMPipe (line 213) | class MixtralForCausalLMPipe(PipelineModel, transformers.MixtralForCausa...
method __init__ (line 214) | def __init__(self, config, quantization_config):
method to_layer_specs (line 227) | def to_layer_specs(self):
class Gemma3ForCausalLMPipe (line 247) | class Gemma3ForCausalLMPipe(PipelineModel, transformers.Gemma3ForCausalLM):
method __init__ (line 248) | def __init__(self, config, quantization_config):
method to_layer_specs (line 257) | def to_layer_specs(self):
class Cohere2ForCausalLMPipe (line 282) | class Cohere2ForCausalLMPipe(PipelineModel, transformers.Cohere2ForCausa...
method __init__ (line 283) | def __init__(self, config, quantization_config):
method to_layer_specs (line 292) | def to_layer_specs(self):
FILE: models/pipeline_model.py
class PipelineModel (line 22) | class PipelineModel(nn.Module):
method __init__ (line 23) | def __init__(self, config, quantization_config, model_config):
method to_layer_specs (line 43) | def to_layer_specs(self):
method set_dpo_reference_mode (line 46) | def set_dpo_reference_mode(self, dpo_reference_mode):
method set_sampling_mode (line 49) | def set_sampling_mode(self, sampling_mode):
method set_cache (line 66) | def set_cache(self, micro_batch_id):
function _partial_module_name_match (line 70) | def _partial_module_name_match(full_name, list_to_match):
function _replace_with_quantized_linear (line 74) | def _replace_with_quantized_linear(parent_modules_map, name, full_name, ...
function _replace_with_bnb_linear (line 83) | def _replace_with_bnb_linear(parent_modules_map, name, full_name, quanti...
function _replace_with_hqq_linear (line 126) | def _replace_with_hqq_linear(parent_modules_map, name, full_name, quanti...
function _recursively_replace_with_quantized_linear (line 151) | def _recursively_replace_with_quantized_linear(
class LoaderUtil (line 191) | class LoaderUtil:
method __init__ (line 192) | def __init__(self, model_path, quantization_config, modules_to_not_qua...
method get_partial_state_dict (line 210) | def get_partial_state_dict(self, leaf_file):
method maybe_quantize (line 218) | def maybe_quantize(self, module):
method load_state_dict_into_module (line 231) | def load_state_dict_into_module(self, module):
FILE: tools/convert_dpo_dataset_to_chat_format.py
function convert (line 15) | def convert(x):
FILE: tools/convert_ds_checkpoint_to_lora.py
function convert_ds_checkpoint_to_lora (line 12) | def convert_ds_checkpoint_to_lora(ds_checkpoint_dir, lora_output_dir):
FILE: tools/merge_lora.py
function find_lora_weights (line 45) | def find_lora_weights(key):
FILE: tools/test_sampling.py
function bnb_cuda_hijack (line 70) | def bnb_cuda_hijack(self, device):
FILE: train.py
function print_model_info (line 47) | def print_model_info(model):
function set_config_defaults (line 61) | def set_config_defaults(config):
function get_most_recent_run_dir (line 66) | def get_most_recent_run_dir(output_dir):
function write_metrics (line 70) | def write_metrics(tb_writer, prefix, metrics, step):
function evaluate_single (line 144) | def evaluate_single(model_engine, name, eval_dataloader, tb_writer, step...
function evaluate (line 168) | def evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradi...
function apply_max_norm_regularization (line 185) | def apply_max_norm_regularization(model, config):
function parse_layers_to_transform (line 232) | def parse_layers_to_transform(spec):
function one_at_a_time (line 242) | def one_at_a_time():
function load_pipeline_model_with_lora (line 249) | def load_pipeline_model_with_lora(config, model_type):
function bnb_cuda_hijack (line 469) | def bnb_cuda_hijack(self, device):
function get_optimizer (line 485) | def get_optimizer(model_parameters):
FILE: utils/dataloader.py
function split_batch (line 25) | def split_batch(example, pieces):
function combine_piecewise (line 41) | def combine_piecewise(a, b, pieces):
function flatten_examples (line 54) | def flatten_examples(examples):
function example_to_tuple (line 67) | def example_to_tuple(example):
function shuffle_list (line 71) | def shuffle_list(l, seed):
function batch_size_tokens_after_padding (line 79) | def batch_size_tokens_after_padding(batch):
class DistributedBatchSamper (line 84) | class DistributedBatchSamper(torch.utils.data.Sampler):
method __init__ (line 85) | def __init__(
method should_emit_current_batch (line 160) | def should_emit_current_batch(self, current_batch, slice):
method __iter__ (line 174) | def __iter__(self):
method __len__ (line 177) | def __len__(self):
class PipelineDataLoader (line 181) | class PipelineDataLoader:
method __init__ (line 182) | def __init__(
method reset (line 248) | def reset(self):
method __iter__ (line 254) | def __iter__(self):
method __len__ (line 257) | def __len__(self):
method __next__ (line 260) | def __next__(self):
method _pull_batches_from_dataloader (line 276) | def _pull_batches_from_dataloader(self):
method _create_dataloader (line 288) | def _create_dataloader(self):
method state_dict (line 296) | def state_dict(self):
method load_state_dict (line 302) | def load_state_dict(self, state_dict):
method sync_epoch (line 317) | def sync_epoch(self):
FILE: utils/dataset_utils.py
function yield_sequences_from_token_batch (line 21) | def yield_sequences_from_token_batch(tokenizer, token_batch, sequence_len):
function slice_into_chunks (line 50) | def slice_into_chunks(x, sequence_len, overlap=0):
function load_raw_dataset (line 58) | def load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size, o...
function load_axolotl_dataset (line 101) | def load_axolotl_dataset(dataset_path, tokenizer, sequence_len, eval_size):
function load_pretokenized_dataset (line 120) | def load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eva...
function load_single_dataset (line 136) | def load_single_dataset(dataset_config, tokenizer):
function combine_datasets (line 190) | def combine_datasets(dataset_list, config, sample_weights):
function process_dataset_for_rejected_sampling (line 219) | def process_dataset_for_rejected_sampling(dataset):
function load_datasets (line 250) | def load_datasets(config, tokenizer):
FILE: utils/engine.py
function initialize (line 47) | def initialize(
function unpack_accepted_rejected (line 72) | def unpack_accepted_rejected(example):
class LoadMicroBatchMultipleBuffers (line 82) | class LoadMicroBatchMultipleBuffers(PipeInstruction):
method __init__ (line 83) | def __init__(self, *buffer_ids, **kwargs):
class ReferenceLogitsForwardPass (line 87) | class ReferenceLogitsForwardPass(BufferOpInstruction):
class CustomPipelineEngine (line 91) | class CustomPipelineEngine(PipelineEngine):
method __init__ (line 92) | def __init__(
method set_dataloader (line 137) | def set_dataloader(self, loader):
method train_batch (line 144) | def train_batch(self):
method eval_batch (line 225) | def eval_batch(self, data_iter):
method sample_batch (line 275) | def sample_batch(self, prompts):
method _sample_from_iterator (line 305) | def _sample_from_iterator(self, data_iter, collate_fn):
method _aggregate_total_losses (line 335) | def _aggregate_total_losses(self):
method _exec_forward_pass (line 389) | def _exec_forward_pass(self, buffer_id):
method _exec_load_micro_batch_multiple_buffers (line 459) | def _exec_load_micro_batch_multiple_buffers(self, buffer_ids):
method _exec_reference_logits_forward_pass (line 512) | def _exec_reference_logits_forward_pass(self, buffer_id):
method _exec_send_micro_batch_id (line 570) | def _exec_send_micro_batch_id(self, send_micro_batch_id):
method _exec_load_micro_batch_for_sampling (line 593) | def _exec_load_micro_batch_for_sampling(self, buffer_id, inputs):
method _exec_sampling_forward_pass (line 603) | def _exec_sampling_forward_pass(self, buffer_id):
method _sample_from_logits (line 657) | def _sample_from_logits(self, buffer_id):
method _valid_stage (line 666) | def _valid_stage(self, stage_id):
method _valid_micro_batch (line 669) | def _valid_micro_batch(self, micro_batch_id):
method _exec_sampling_schedule (line 672) | def _exec_sampling_schedule(self, examples, feature_prefix='', max_tot...
class ColumnMajorParallelTopology (line 825) | class ColumnMajorParallelTopology(ProcessTopology):
method __init__ (line 832) | def __init__(self, num_pp, num_dp):
class CustomPipelineModule (line 837) | class CustomPipelineModule(PipelineModule):
method __init__ (line 838) | def __init__(self, layers, use_column_major_topology, model=None, **kw...
method model (line 854) | def model(self):
method set_dpo_reference_mode (line 857) | def set_dpo_reference_mode(self, dpo_reference_mode):
method set_sampling_mode (line 860) | def set_sampling_mode(self, sampling_mode):
method _partition_layers (line 863) | def _partition_layers(self, method='uniform'):
class DPOTrainSchedule (line 926) | class DPOTrainSchedule(PipeSchedule):
method steps (line 929) | def steps(self):
method num_pipe_buffers (line 986) | def num_pipe_buffers(self):
method _step_to_micro_batch (line 993) | def _step_to_micro_batch(self, step_id):
method _even_step_forward_id (line 1015) | def _even_step_forward_id(self, step_id):
method _odd_step_forward_id (line 1020) | def _odd_step_forward_id(self, step_id):
method _even_step_backward_id (line 1025) | def _even_step_backward_id(self, step_id):
method _odd_step_backward_id (line 1030) | def _odd_step_backward_id(self, step_id):
method _buffer_idx (line 1036) | def _buffer_idx(self, micro_batch_id):
class DPOInferenceSchedule (line 1041) | class DPOInferenceSchedule(PipeSchedule):
method steps (line 1042) | def steps(self):
method num_pipe_buffers (line 1089) | def num_pipe_buffers(self):
FILE: utils/hqq_utils.py
function _maybe_include_all_linear_layers (line 15) | def _maybe_include_all_linear_layers(peft_config: peft.PeftConfig, model...
class CustomHQQConfig (line 55) | class CustomHQQConfig:
method __post_init__ (line 64) | def __post_init__(self):
method use_aten (line 67) | def use_aten(self):
method get_dict (line 70) | def get_dict(self, full_name):
FILE: utils/saver.py
function need_to_checkpoint (line 20) | def need_to_checkpoint(config):
function convert_state_dict_dtype (line 38) | def convert_state_dict_dtype(state_dict, dtype):
class Saver (line 43) | class Saver:
method __init__ (line 44) | def __init__(self, model_engine, pipeline_model, train_dataloader, lor...
method save_lora (line 69) | def save_lora(self, name):
method save_full_model (line 105) | def save_full_model(self, name, max_shard_size='5GB'):
method will_save (line 153) | def will_save(self, type, name):
method save_model (line 169) | def save_model(self, name):
method save_checkpoint (line 175) | def save_checkpoint(self, step):
method process_epoch (line 187) | def process_epoch(self, epoch, step):
method process_step (line 202) | def process_step(self, step):
method append_eval_results (line 247) | def append_eval_results(self, loss, save_best=True):
method safe_rmtree (line 265) | def safe_rmtree(self, dir_path, max_retries=5, initial_wait_seconds=1):
FILE: utils/unsloth_utils.py
class Unsloth_Offloaded_Gradient_Checkpointer (line 23) | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
method forward (line 32) | def forward(ctx, forward_function, hidden_states, *args):
method backward (line 45) | def backward(ctx, *grads):
function unsloth_checkpoint (line 70) | def unsloth_checkpoint(function, *args):
FILE: utils/utils.py
function log (line 16) | def log(msg):
function eta_str (line 20) | def eta_str(eta):
Condensed preview: 30 files, each showing path, character count, and a content snippet (full structured content: 246K chars).
[
{
"path": ".gitignore",
"chars": 3078,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": ".gitmodules",
"chars": 98,
"preview": "[submodule \"axolotl\"]\n\tpath = axolotl\n\turl = https://github.com/OpenAccess-AI-Collective/axolotl/\n"
},
{
"path": "LICENSE",
"chars": 1066,
"preview": "MIT License\n\nCopyright (c) 2023 tdrussell\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
},
{
"path": "README.md",
"chars": 12618,
"preview": "# qlora-pipe\nA pipeline parallel training script for LLMs.\n\nRefer to the changelog at the bottom for details on updates."
},
{
"path": "examples/alpaca.yml",
"chars": 58,
"preview": "datasets:\n - path: vicgalle/alpaca-gpt4\n type: alpaca\n"
},
{
"path": "examples/capybara.yml",
"chars": 189,
"preview": "chat_template: llama3\ndatasets:\n - path: ssmi153/Capybara-ShareGPT\n type: chat_template\n\n field_messages: convers"
},
{
"path": "examples/config.toml",
"chars": 7635,
"preview": "# Paths\nmodel = '/data2/models/Meta-Llama-3.1-8B'\noutput_dir = '/data/training_runs/llama3_8b_example'\n\n# Lora configura"
},
{
"path": "examples/config_dpo.toml",
"chars": 1044,
"preview": "# Paths\nmodel = '/data2/models/Meta-Llama-3.1-8B-Instruct'\noutput_dir = '/data/training_runs/llama3_8b_dpo_example'\n\nlor"
},
{
"path": "examples/converted_dpo_dataset.yml",
"chars": 425,
"preview": "# Some DPO datasets are not in conversation format.\n# They need to be in conversation format to load them in this script"
},
{
"path": "examples/ds_config.json",
"chars": 138,
"preview": "{\n \"train_micro_batch_size_per_gpu\": 1,\n \"gradient_accumulation_steps\": 1,\n \"gradient_clipping\": 1.0,\n \"step"
},
{
"path": "examples/ultrafeedback.yml",
"chars": 220,
"preview": "# This dataset is already in a format that can be directly loaded by the orpo.chat_template type.\nchat_template: llama3\n"
},
{
"path": "kernels/cross_entropy_loss.py",
"chars": 10647,
"preview": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License,"
},
{
"path": "kernels/utils.py",
"chars": 8422,
"preview": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License,"
},
{
"path": "models/layers.py",
"chars": 29162,
"preview": "import math\n\nimport torch\nimport torch.nn.functional as F\nimport transformers\nfrom deepspeed.runtime.pipe import module "
},
{
"path": "models/models.py",
"chars": 15201,
"preview": "import accelerate\nimport torch\nimport transformers\n\nfrom models.layers import (\n InputLayer,\n LayerSpec,\n Llama"
},
{
"path": "models/pipeline_model.py",
"chars": 12112,
"preview": "import os\nfrom collections import defaultdict\nfrom inspect import signature\nimport re\n\nimport accelerate\nimport bitsandb"
},
{
"path": "pyproject.toml",
"chars": 1161,
"preview": "[tool.black]\n# Only used by `hf-doc-builder´.\nline-length = 119\ntarget-version = ['py38']\n\n[tool.ruff]\ntarget-version = "
},
{
"path": "requirements.txt",
"chars": 341,
"preview": "torch\ntorchvision\ntorchaudio\naccelerate\nbitsandbytes\ndatasets\ndeepspeed\npackaging\npeft\nsafetensors\nscipy\nsentencepiece\nt"
},
{
"path": "tools/convert_dpo_dataset_to_chat_format.py",
"chars": 838,
"preview": "# Convert a DPO dataset with prompt, chosen, rejected fields into chat format.\n# Usage: python convert_dpo_dataset_to_ch"
},
{
"path": "tools/convert_ds_checkpoint_to_lora.py",
"chars": 1340,
"preview": "# Very hacky script to convert pipeline parallel Deepspeed checkpoints into a saved lora model.\n# I originally wrote thi"
},
{
"path": "tools/merge_lora.py",
"chars": 3442,
"preview": "# Usage: python merge_lora.py input_path lora_path output_path\n# Output path is created if it doesn't exist\n\nimport argp"
},
{
"path": "tools/test_sampling.py",
"chars": 3556,
"preview": "# deepspeed --num_gpus=1 --module tools.test_sampling --config ~/code/qlora-pipe-configs/config_8b_dpo.toml\n\nimport argp"
},
{
"path": "train.py",
"chars": 32688,
"preview": "import argparse\nimport glob\nimport itertools\nimport json\nimport os\nimport shutil\nimport time\nfrom contextlib import cont"
},
{
"path": "utils/dataloader.py",
"chars": 14647,
"preview": "import math\nimport os.path\nimport sys\n\n\nsys.path.insert(0, os.path.abspath('axolotl/src'))\n\nimport accelerate\nimport tor"
},
{
"path": "utils/dataset_utils.py",
"chars": 12321,
"preview": "import os\nimport os.path\nimport sys\n\n\nsys.path.insert(0, os.path.abspath('axolotl/src'))\n\nimport datasets\nimport torch\ni"
},
{
"path": "utils/engine.py",
"chars": 47157,
"preview": "from collections import deque\nimport time\n\nimport deepspeed\nimport torch\nimport transformers\nfrom deepspeed import comm "
},
{
"path": "utils/hqq_utils.py",
"chars": 3082,
"preview": "from dataclasses import asdict, dataclass, field\nfrom typing import Any\n\nimport torch\nimport transformers\nfrom hqq.core "
},
{
"path": "utils/saver.py",
"chars": 12386,
"preview": "import glob\nimport json\nimport os\nimport shutil\nimport sys\nimport time\nfrom pathlib import Path\n\nimport deepspeed\nimport"
},
{
"path": "utils/unsloth_utils.py",
"chars": 2533,
"preview": "# Unsloth Zoo - Utilities for Unsloth\n# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n"
},
{
"path": "utils/utils.py",
"chars": 631,
"preview": "import os.path\nimport sys\nfrom datetime import datetime\n\nimport torch\n\n\nsys.path.insert(0, os.path.abspath('axolotl/src'"
}
]
About this extraction
This page contains the full source code of the tdrussell/qlora-pipe GitHub repository, extracted and formatted as plain text. The extraction includes 30 files (232.7 KB), approximately 55.0k tokens, and a symbol index with 222 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.