[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"axolotl\"]\n\tpath = axolotl\n\turl = https://github.com/OpenAccess-AI-Collective/axolotl/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 tdrussell\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# qlora-pipe\nA pipeline parallel training script for LLMs.\n\nRefer to the changelog at the bottom for details on updates.\n\n## About\nThis is a training script I made so that I can fine-tune LLMs using my workstation with four 4090s. It is developed first and foremost for myself, with my own use cases in mind. It is scrappy and hacked together. It will likely *never* be a stable, well-supported training script like Axolotl. I am open sourcing the code in case it is useful to others, and also as a proof-of-concept that this kind of thing is possible.\n\nThat being said, if something doesn't work right, or you would like it to support some feature, feel free to raise an issue and I'll try to look at it.\n\n## Features\n- Pipeline parallel training, for efficiently training large models that cannot fit on one GPU\n- Supports QLoRA, LoRA, and full fine tuning\n- Quantize weights using either bitsandbytes or HQQ\n- Efficient model loading. Each process only loads the layers it needs, and quantizes and moves them to the GPU layer-by-layer. 
This means you can load a large model on a lot of GPUs even with limited system RAM.\n- Load any dataset that Axolotl can, using the same YAML config file format\n- Support for \"raw text\" training using either a structured list of documents in a JSON file, or a single txt file\n- Support for resuming training from a checkpoint, including the dataloader state, to easily allow training in a piecemeal fashion\n- Useful metrics logged to Tensorboard\n- Ability to specify a separate, fixed evaluation dataset\n- Train on multiple datasets simultaneously, with different sampling ratios per dataset\n- Models currently supported: Llama, Mistral, Mixtral, Qwen, Cohere (Command R), Phi-3 (mini and medium), Gemma 2, Gemma 3, Cohere2 (Command-A)\n\n## Installing\nClone the repository:\n```\ngit clone --recurse-submodules https://github.com/tdrussell/qlora-pipe\n```\n\nIf you already cloned it and forgot to use --recurse-submodules:\n```\ngit submodule init\ngit submodule update\n```\n\nInstall Miniconda: https://docs.conda.io/en/latest/miniconda.html\n\nCreate the environment:\n```\nconda create -n qlora-pipe python=3.12\nconda activate qlora-pipe\n```\n\nInstall the dependencies:\n```\npip install -r requirements.txt\n```\n\nInstall nvcc:\n```\nconda install nvidia::cuda-nvcc\n```\n\n## Training\n__Start by reading through the config files in the examples directory__. There are lots of comments explaining what the various fields do. Then, make a copy and edit it however you like. At minimum, change the paths at the top to point to your model and desired output directory. Launch the training script:\n```\nNCCL_P2P_DISABLE=\"1\" NCCL_IB_DISABLE=\"1\" deepspeed --num_gpus=1 train.py --deepspeed --config examples/config.toml\n```\nRTX 4000 series GPUs need those two environment variables set. Other GPUs may not need them.\n\n## Parallelism\nDeepspeed handles pipeline- and data-parallelism. Set the --num_gpus flag to however many GPUs you want to use. 
The config option `pipeline_stages` determines the level of model parallelism. Then, the data parallelism is automatically set so that all GPUs are used.\n\nFor example, with 8 GPUs and pipeline_stages=4, a single instance of the model is divided across 4 GPUs. Because there are 8 GPUs total, there are then 2 data-parallel instances.\n\nThe option `gradient_accumulation_steps` in the TOML config file (see examples/config.toml) determines the amount of pipelining when using pipeline parallelism (pipeline_stages>1). The higher the value, the more the GPUs can overlap computation. For example, with gradient_accumulation_steps=1, there is a single batch that gets passed forward through the GPUs, then in reverse for the backward pass. Only 1 GPU is active at a time; the others are idle. As gradient_accumulation_steps increases, you start pipelining multiple forward/backward batches. At the beginning and end of each step, some GPUs will always be idle. So as gradient_accumulation_steps approaches infinity, you approach 100% theoretical utilization. In practice, a value of 8 or so already gives good average utilization with 2 GPUs. With more GPUs, you may want to go higher.\n\n## Dataset configuration\nThere are 3 options for specifying each dataset. Set the `dataset_type` field to one of:\n- axolotl\n  - Loads the dataset using the Axolotl codebase. Set `dataset_path` to a YAML file that contains the same dataset configuration you would use in Axolotl.\n- doclist\n  - Set `dataset_path` to a glob pattern matching one or more JSON or JSONL files. Each file should be a list of objects containing a 'text' key. For each file, all of the text is logically concatenated together, before being sliced into sequences.\n- textfile\n  - Basically the same as doclist, except the `dataset_path` matches one or more txt files. Each text file is sliced into sequences.\n\nYou can read dataset_utils.py for details on what each of these options is doing.\n\nYou can have multiple datasets. 
Just add additional `[[datasets]]` entries. When using multiple datasets, there are different ways to combine them.\n- `dataset_combination_mode` = 'concatenate' (the default)\n  - Just concatenates the datasets.\n- `dataset_combination_mode` = 'interleave'\n  - Uses the Hugging Face Datasets library `interleave_datasets()` function.\n  - Use the `dataset_interleave_stopping_strategy` setting to control when interleaving stops.\n    - 'first_exhausted': stop as soon as any one dataset runs out of examples.\n    - 'all_exhausted': stop when all datasets have run out of examples. This duplicates examples from smaller datasets.\n  - When using the 'interleave' mode, datasets can have a relative `sample_weight`, which is a positive real number. This controls the relative proportion of the datasets when they are combined.\n  - __IMPORTANT__: When using the 'interleave' mode, the manner in which the datasets are proportionally combined (i.e. sampled from) is affected by the `batch_size_tokens` setting:\n    - If `batch_size_tokens` is unset, it means you are treating each example equally. Every batch has the same number of examples, even though they may be different lengths. So, when interleaving datasets, the rows are sampled according to the relative proportions given by the `sample_weight`.\n    - If using `batch_size_tokens`, it means you are treating each token equally. Every batch varies the number of examples (because they might have different lengths) so that the token count is approximately constant. So, when interleaving datasets, the sampling ratios are adjusted so that the number of *tokens*, not rows, drawn from different datasets matches the `sample_weight`. This is implemented by scaling each dataset's sampling probability inversely to its average row length, so datasets with shorter rows contribute proportionally more rows. You can read the `combine_datasets()` function in dataset_utils.py if this is confusing.\n    - __Which of these should I use?__ Probably set `batch_size_tokens`. 
I think this is the better way to think about things, and it matches what sample packing would do. For example, in Axolotl, it is recommended to use sample packing, which packs multiple examples into a single sequence so that the sequence length is constant. This means, in the loss function, each token is being treated with equal weight, not each original row in the dataset. Using `batch_size_tokens` in this training script mimics that behavior, and thus when interleaving datasets, it samples from them so that the token ratios adhere to the sample_weight specified.\n    - __Example__: you have datasets A and B. B's average row length is twice that of A. A has a sample_weight of 2, B has a sample_weight of 1.\n      - Not setting batch_size_tokens: when interleaving, you get 2 rows of A for every row of B.\n      - Using batch_size_tokens: when interleaving, you get 4 rows of A for every row of B. This is because A's rows are on average half the length of B's rows, so you need twice as many as before so that the number of tokens in each matches the 2:1 ratio you specified with the sample_weight.\n\n## On sample packing (or the lack thereof)\nSample packing is not currently implemented. Instead, there is the option `batch_size_tokens`. If this field is set, the per-device batch size is ignored, and instead the batch size is adjusted dynamically to target a fixed number of tokens per batch, per device. This was easier to implement than sample packing, and does basically the same thing. It is also efficient: if I set batch_size_tokens to a modest 10000 and train a 7B model with the Alpaca dataset, all my 4090s hit their 350W power limit cap. Unless I'm missing something (definitely possible), it seems there is no need to support sample packing.\n\n## Floating point precision\nThere are different places you can specify the floating point dtype. 
`model_weight_dtype` controls the precision of the underlying model weights (for any weights not quantized), and `lora_weight_dtype` is for the lora weights. If you are using quantization, both bnb and hqq have options for the compute dtype as well.\n\nIf you are using 16 bit dtypes, floating point roundoff error is a potential problem. For a good overview of the problem and solutions, see [Revisiting Bfloat16 Training](https://arxiv.org/pdf/2010.06192). TLDR: the main source of precision error when training with 16 bit weights is the weight update step: $(p = p + \\Delta p * lr)$. When the update is very small compared to the parameter (which is often the case), there can be significant roundoff error, including the update being entirely dropped. Mixed precision training solves this by keeping a master copy of the weights in fp32, and running all optimizer steps in fp32. Kahan summation is another solution when training in full bf16, that keeps an extra bf16 buffer for each parameter to accumulate roundoff errors so that updates are never dropped.\n\n### Okay but how should I configure things?\n - If unsure, set everything to bf16 and use the adamw_kahan optimizer type. Kahan summation is ESPECIALLY important for full fine tuning. Kahan summation requires an extra 2 bytes per trainable parameter compared to vanilla full bf16 training.\n - For LoRAs, another option is setting `lora_weight_dtype` to fp32, which also makes all optimizer states fp32.\n - For LoRAs only, with constant learning rate no lower than 5e-5 or so, I have seen full bf16 training with no Kahan summation mostly match fp32 or bf16 + Kahan.\n - (more experimental) You may try Deepspeed's bf16 mode, but I personally don't use this. I think this does something like mixed precision, where it wraps the optimizer to keep a master copy of the parameters in fp32, as well as doing gradient accumulation and all optimizer states in fp32. 
This will use much more memory than full bf16 + Kahan summation.\n\n## Changelog\n### 2025-03-12\n- Change how weights are loaded to avoid relying on an internal Transformers method\n- Support Gemma 3\n### 2025-01-30\n- Add pretokenized dataset option.\n- Update layers to work with the new way to pass position embeddings in HF Transformers. Please update Transformers to the latest version or you will get errors.\n### 2024-12-29\n- Add DPO training. The examples directory has a DPO example.\n### 2024-07-02\n- Add Gemma-2 support.\n### 2024-06-20\n- Add adamw_kahan optimizer type and make it the default in the example.\n### 2024-05-19\n**The old config file format will break.** Quantization is configured slightly differently now. Read examples/config.toml (this entry originally referred to examples/config_7b.toml, which no longer exists). It's only a few lines to change.\n- Change how quantization is configured. Quantization is now its own table in the TOML file.\n- Add HQQ quantization.\n### 2024-04-28\n- Add llama3 instruction formatting option when loading a ShareGPT formatted dataset using Axolotl.\n- Automatically add BOS token for Llama 3.\n- Add option for Unsloth activation checkpointing, which saves VRAM for a very small hit to performance.\n### 2024-04-16\n- Optimizer is now specified in the config.toml file.\n- Can use AdamW8bit optimizer.\n- MLP offloading works again. For MoE, can offload a specified number of experts.\n- Can have separate dtype for saved files.\n- Cohere model support (command-r)\n### 2024-04-07\nMake sure to update requirements! Axolotl does some dynamic importing, so things will break in a very hard-to-diagnose way if you don't have a new dependency that was added.\n- Removed the need for manually specifying cache directories for datasets. All dataset processing uses the Hugging Face Datasets library and takes advantage of the automatic caching that it provides.\n- Added the ability to specify multiple datasets, with different ways to combine them. 
__This breaks the old config format for datasets.__ Refer to the example config for what it should look like now."
  },
  {
    "path": "examples/alpaca.yml",
    "content": "datasets:\n  - path: vicgalle/alpaca-gpt4\n    type: alpaca\n"
  },
  {
    "path": "examples/capybara.yml",
    "content": "chat_template: llama3\ndatasets:\n  - path: ssmi153/Capybara-ShareGPT\n    type: chat_template\n\n    field_messages: conversations\n    message_field_role: from\n    message_field_content: value\n"
  },
  {
    "path": "examples/config.toml",
    "content": "# Paths\nmodel = '/data2/models/Meta-Llama-3.1-8B'\noutput_dir = '/data/training_runs/llama3_8b_example'\n\n# Lora configuration\n# can use full_fine_tune=true and no quantization to train the whole model instead of a LoRA\n#full_fine_tune = true\nlora_rank = 64\nlora_alpha = 64\nlora_dropout = 0.05\n\n# Train only specific modules. This is passed to the parameter of the same name in the LoraConfig.\n# If not set, adapt all linear modules.\n# Note, this ALSO affects full fine tuning. In that case, if this is set, only weights containing one\n# of these keys as substring will have requires_grad. If not set everything is trained.\n#target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']\n\n# can specify layers to adapt with LoRA if you want\n#layers_to_transform = '16:31'\n\n# Training settings\n\nepochs = 2\nlr_scheduler = 'cosine'  # can also be 'constant'\nwarmup_steps = 100\n# Batch size of a single forward/backward pass for one GPU.\nmicro_batch_size_per_gpu = 1\n# Dynamic batch size, targeting this many tokens per batch, per device.\n# If set, completely ignores micro_batch_size_per_gpu.\n# Can be thought of as a replacement for sample packing.\nbatch_size_tokens = 5000\n# Number of pipeline parallel stages, must evenly divide the number of GPUs you launch the script with. The lower this is, the more GPU VRAM you need, but the faster the training will be.\n# A value of 1 means the model will be loaded on every GPU in full, each GPU running its own training batches.\n# A value of 2 means the model will be split into half, the halves loaded evenly across the (2, 4, 6, 8, ...) GPUs, where GPUs will work in pairs on each batch. 
(And so on.)\npipeline_stages = 1\n# Number of micro-batches sent through the pipeline for each training step.\n# If pipeline_stages > 1, a higher GAS means better GPU utilization due to smaller pipeline bubbles (where GPUs aren't overlapping computation).\ngradient_accumulation_steps = 4\n# Grad norm clipping.\ngradient_clipping = 1.0\n# might be useful if resuming from a checkpoint and you want to change the LR and force it to something\n#force_constant_lr = 5e-5\n# hard clamp the magnitude of the LoRA weights\n#scale_weight_norms = 1.0\n# for Mixtral, set the load balancing coefficient\n#load_balancing_loss_coef = 0.02\n\n# Eval settings\n\neval_steps = 100  # how often to run eval\neval_before_first_step = true  # do an eval before any training happens\neval_after_last_step = false # do a final eval after the training completes\n\n# Performance settings\n\nlogging_steps = 10  # how often to log in Tensorboard\nsave_steps = 200  # how often to save the model\ncheckpoint_every_n_minutes = 60  # how frequently to checkpoint training states (used for resuming training)\n# checkpoint_on_save = true  # alternative to the above, this will cause a checkpoint save every time a regular save occurs (note: this setting takes precedence over checkpoint_every_n_minutes)\n# dtype to load the underlying model weights in\nmodel_weight_dtype = 'bfloat16'\n# dtype for the LoRA weights\nlora_weight_dtype = 'bfloat16'\n# Can have the saved weights be different dtype. Don't need to set this. 
Could be useful for\n# training in float32 but saving with float16.\n#save_dtype = 'bfloat16'\n# Keep this number of stepXXXX (model saves) and global_stepXXX (checkpoint saves) and delete the rest\n# (this only applies to the current training session, and resumed training sessions will not touch\n# old saves)\nkeep_states = 3\n\n# sort examples by length before dividing them into batches\n# this makes all examples in a batch approximately the same length, to minimize padding\n# the batches are still shuffled after that\n# you should probably always have this set to true\ngroup_by_length = true\n\n# This can also be 'unsloth' to offload hidden states to CPU, saving potentially a lot of VRAM\n# for a minor performance hit.\n# Example: 4x4090, PCIE 3.0 16x, pipeline_stages=4, training QLoRA on Llama 3 70B with 4096 sequence length.\n# true: 75s step time, 19.7G peak per-GPU VRAM usage.\n# 'unsloth': 78s step time, 16.2G peak per-GPU VRAM usage.\nactivation_checkpointing = true\n\n# Keep MLP weights on system RAM until they are needed. Can save a ton of VRAM with a\n# moderate hit to performance. If using an MoE model, this can also be an integer, in\n# which case only that many experts are offloaded (tradeoff between VRAM and speed).\n#offload_mlp_to_cpu = true\n\n# Resume a prior run\n# if true, we attempt to resume training from the most recent directory inside output_dir (the directory names are timestamps)\n# so, to resume, just run the exact same command but set this to true first\nresume_from_checkpoint = false\n\n# Loading the optimizer states seems to cause some kind of unavoidable VRAM memory leak.\n# It's very small, only about 0.2 GB in cases I've seen. But if you are very close to the\n# limit, it can cause resuming from checkpoint to OOM. 
As a last resort, you can uncomment\n# this to not load the optimizer states and hopefully the resumption won't OOM.\n#load_optimizer_states = false\n\n\n# Dataset configuration\n\n# How to combine multiple datasets if you have more than one.\n# Can be 'concatenate' or 'interleave'. Will be 'concatenate' if not set.\ndataset_combination_mode = 'interleave'\n# When to stop interleaving datasets when using mode 'interleave'. Either 'first_exhausted' or 'all_exhausted'.\n# Default if not set: 'first_exhausted'\ndataset_interleave_stopping_strategy = 'all_exhausted'\n# Can set this lower than training, so we don't drop as many examples when trying to make equal-sized batches.\n# Default if not set: same as training GAS.\neval_gradient_accumulation_steps = 1\n\n# bitsandbytes 4 bit quantization. The parameters here become arguments to Transformers BitsAndBytesConfig.\n[quantization.bnb]\nload_in_4bit = true\nbnb_4bit_use_double_quant = false\nbnb_4bit_compute_dtype = 'bfloat16'\n\n# HQQ quantization. The parameters here become arguments to CustomHQQConfig.\n# [quantization.hqq]\n# nbits = 4\n# group_size = 64\n# compute_dtype = 'bfloat16'\n\n# (Optional) You can override the quant params for certain modules. This does substring matching, e.g. if 'gate_proj'\n# is a substring of the full module name, anything specified overwrites the defaults in [quantization.hqq].\n# [quantization.hqq.dynamic_config]\n# gate_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true}\n# up_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true}\n# down_proj = {nbits = 2, group_size = 16, quant_zero = true, quant_scale = true}\n\n[optimizer]\n# options: adamw_kahan, AdamW, AdamW8bit\ntype = 'adamw_kahan'\nlr = 5e-5\nbeta1 = 0.9\nbeta2 = 0.99\nweight_decay = 0.1\n\n[[datasets]]\n# Arbitrary name, used only for separately logging eval metrics. 
Will be dataset0, dataset1, etc if not set.\nname = 'alpaca'\ndataset_type = 'axolotl'\ndataset_path = 'examples/alpaca.yml'\nsequence_len = 2048\neval_size = 0.02\n# Relative sampling weight, when using combination mode 'interleave'. Will be 1 if not set.\nsample_weight = 1\n\n[[datasets]]\nname = 'capybara'\ndataset_type = 'axolotl'\ndataset_path = 'examples/capybara.yml'\nsequence_len = 2048\neval_size = 0.02\nsample_weight = 1.5\n\n# In addition to using eval_size which splits off some of the dataset, we can have completely separate datasets for eval.\n# This can be useful if you're training on raw text data, so that the eval set remains completely fixed, even if\n# you change training sequence_len, etc.\n# This is just an example, typically you wouldn't have this overlap a training dataset.\n# [[eval_datasets]]\n# name = 'capybara'\n# dataset_type = 'axolotl'\n# dataset_path = 'examples/capybara.yml'\n# sequence_len = 2048\n"
  },
  {
    "path": "examples/config_dpo.toml",
    "content": "# Paths\nmodel = '/data2/models/Meta-Llama-3.1-8B-Instruct'\noutput_dir = '/data/training_runs/llama3_8b_dpo_example'\n\nlora_rank = 64\nlora_alpha = 64\nlora_dropout = 0.05\n\nepochs = 2\nlr_scheduler = 'constant'\nwarmup_steps = 100\nbatch_size_tokens = 5000\nmicro_batch_size_per_gpu = 1\npipeline_stages = 1\ngradient_accumulation_steps = 4\ngradient_clipping = 1.0\n\neval_steps = 100\neval_before_first_step = true\neval_after_last_step = false\n\nlogging_steps = 10\nsave_steps = 200\ncheckpoint_every_n_minutes = 60\nmodel_weight_dtype = 'bfloat16'\nlora_weight_dtype = 'bfloat16'\n\ngroup_by_length = true\nactivation_checkpointing = true\n\neval_gradient_accumulation_steps = 1\n\n[rl]\nmethod = 'dpo'\ndpo_beta = 0.02\n\n# [quantization.bnb]\n# load_in_4bit = true\n# bnb_4bit_use_double_quant = false\n# bnb_4bit_compute_dtype = 'bfloat16'\n\n[optimizer]\ntype = 'adamw_kahan'\nlr = 5e-5\nbeta1 = 0.9\nbeta2 = 0.99\nweight_decay = 0.1\n\n[[datasets]]\nname = 'ultrafeedback'\ndataset_type = 'axolotl'\ndataset_path = 'examples/ultrafeedback.yml'\nsequence_len = 4096\neval_size = 0.01\n"
  },
  {
    "path": "examples/converted_dpo_dataset.yml",
    "content": "# Some DPO datasets are not in conversation format.\n# They need to be in conversation format to load them in this script. You can convert them:\n# python tools/convert_dpo_dataset_to_chat_format.py unalignment/toxic-dpo-v0.2 ~/data/toxic-dpo-v0.2-converted\nchat_template: llama3\ndatasets:\n  - path: json\n    data_files:\n      - /home/anon/data/toxic-dpo-v0.2-converted/train.json\n    split: train\n    type: orpo.chat_template\n"
  },
  {
    "path": "examples/ds_config.json",
    "content": "{\n    \"train_micro_batch_size_per_gpu\": 1,\n    \"gradient_accumulation_steps\": 1,\n    \"gradient_clipping\": 1.0,\n    \"steps_per_print\": 1\n}\n"
  },
  {
    "path": "examples/ultrafeedback.yml",
    "content": "# This dataset is already in a format that can be directly loaded by the orpo.chat_template type.\nchat_template: llama3\ndatasets:\n  - path: argilla/ultrafeedback-binarized-preferences-cleaned\n    type: orpo.chat_template"
  },
  {
    "path": "kernels/cross_entropy_loss.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport triton\nimport triton.language as tl\n\nfrom .utils import MAX_FUSED_SIZE, calculate_settings, device_warp_size\n\n\n@triton.heuristics(\n    {\n        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],\n    }\n)\n@triton.jit\ndef _cross_entropy_forward(\n    logits_ptr,\n    logits_row_stride,\n    loss_ptr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    DO_LOGIT_SCALING: tl.constexpr,\n    LOGIT_SCALE: tl.constexpr,\n):\n    \"\"\"\n    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]\n    Pi = exp(xi) / sum(exp(xi))\n    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]\n         = -y [ x - log[sum(exp(x))] ]\n         = y * (log[sum(exp(x))] - x)\n    If y == 0: CE_i = 0\n    If y == 1: CE_i = logsumexp - x\n\n    logsumexp is also stable\n    Take    y =         log[sum(exp(x))]\n       exp(y) =             sum(exp(x))\n       exp(y) =             sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x\n       exp(y) =      exp(c)*sum(exp(x - c))\n           y  = log(exp(c)*sum(exp(x - c)))\n           y  = c + log[sum(exp(x - c))]\n    This means we can set c = max(x) to make sure\n    exp(x - c) always is exp(x - max(x)).\n    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.\n    \"\"\"\n    row_idx = tl.program_id(0)\n    
logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n    loss_ptr += row_idx\n    logsumexp_ptr += row_idx\n    labels_ptr += row_idx\n\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n\n    label_idx = tl.load(labels_ptr).to(tl.int32)\n    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)\n    if DO_LOGIT_SCALING:\n        # Logit scaling: s * x\n        logits = LOGIT_SCALE * logits\n    pass\n    c = tl.max(logits, 0)\n    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n    if label_idx != -100:\n        x = tl.load(logits_ptr + label_idx).to(tl.float32)\n        if DO_LOGIT_SCALING:\n            # Logit scaling: s * x\n            x = LOGIT_SCALE * x\n        pass\n        loss = logsumexp - x\n    else:\n        loss = 0.0\n    tl.store(logsumexp_ptr, logsumexp)\n    tl.store(loss_ptr, loss)\n\n\npass\n\n\n@triton.heuristics(\n    {\n        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],\n    }\n)\n@triton.jit\ndef _chunked_cross_entropy_forward(\n    logits_ptr,\n    logits_row_stride,\n    loss_ptr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    N_CHUNKS: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    DO_LOGIT_SCALING: tl.constexpr,\n    LOGIT_SCALE: tl.constexpr,\n):\n    \"\"\"\n    256K vocab divided in 4 chunks\n\n    |-65536-| |-65536-| |-65536-| |-65536-|\n    |-------| |-------| |-------| |-------|\n    |-------| |-------| |-------| |-------|\n\n    If y == 0: CE_i = 0\n    If y == 1: CE_i = logsumexp - x\n\n    Notice we can do logsumexp for each chunk and then\n    logsumexp[chunk_sum(logsumexp)] == logsumexp\n\n    chunk_sum = log[chunk_sum(logsumexp)]\n              = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]\n              = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]\n              = log[sum(exp(a)) + ... 
+ sum(exp(z))]\n              = logsumexp(x)\n\n    This means we can perform a logsumexp for each chunk, then do a\n    final logsumexp reduction!\n\n    Ie do: logsumexp(chunked_logsumexp) - x\n    \"\"\"\n    row_idx = tl.program_id(0)\n    chunk_idx = tl.program_id(1)\n    logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n    loss_ptr += row_idx\n    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n    labels_ptr += row_idx\n\n    col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n\n    label_idx = tl.load(labels_ptr).to(tl.int32)\n    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)\n    if DO_LOGIT_SCALING:\n        # Logit scaling: s * x\n        logits = LOGIT_SCALE * logits\n    pass\n    c = tl.max(logits, 0)\n    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n    if chunk_idx == 0:\n        # logsumexp(chunked_logsumexp) - x\n        # Do the -x separately\n        if label_idx != -100:\n            x = tl.load(logits_ptr + label_idx).to(tl.float32)\n            if DO_LOGIT_SCALING:\n                # Logit scaling: s * x\n                x = LOGIT_SCALE * x\n            pass\n            loss = -1.0 * x\n        else:\n            loss = 0.0\n        tl.store(loss_ptr, loss)\n    pass\n    tl.store(logsumexp_ptr, logsumexp)\n\n\npass\n\n\n@triton.heuristics(\n    {\n        'DO_LOGIT_SCALING': lambda args: args['DO_LOGIT_SCALING'],\n    }\n)\n@triton.jit\ndef _cross_entropy_backward(\n    logits_ptr,\n    logits_row_stride,\n    dloss_ptr,\n    dloss_row_stride,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    DO_LOGIT_SCALING: tl.constexpr,\n    LOGIT_SCALE: tl.constexpr,\n):\n    \"\"\"\n    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)\n    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)\n\n    From https://en.wikipedia.org/wiki/LogSumExp\n    d/dx logsumexp = exp(x) / sum(exp(x)) = 
softmax(x)\n\n    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)\n    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick\n    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)\n\n    If y == 0: dC/dx = 0\n    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1\n    If y == 1 and x != label: dC/dx     = exp[x - logsumexp]\n    \"\"\"\n    row_idx = tl.program_id(0)\n    block_idx = tl.program_id(1)\n\n    logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n    dloss_ptr += row_idx * dloss_row_stride\n    col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n    if label_idx != -100:\n        dloss = tl.load(dloss_ptr)\n    else:\n        dloss = 0.0\n\n    x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float('inf')).to(tl.float32)\n    if DO_LOGIT_SCALING:\n        # d/dx [s * x] = s\n        x = LOGIT_SCALE * x\n    pass\n    logsumexp = tl.load(logsumexp_ptr + row_idx)\n    y = tl.exp(x - logsumexp)\n    y = tl.where(\n        col_offsets == label_idx,\n        y - 1.0,  # exp(x - logsumexp) - 1\n        y,  # exp(x - logsumexp)\n    )\n\n    # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.\n    if DO_LOGIT_SCALING:\n        # d/dx [s * x] = s\n        y = LOGIT_SCALE * y\n    pass\n    tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)\n\n\npass\n\n\nclass Fast_CrossEntropyLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, logits, labels, logit_scale=1.0):\n        n_rows, vocab_size = logits.shape\n\n        div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n        n_chunks = div + (mod != 0)\n        losses = torch.empty(n_rows, dtype=torch.float32, device='cuda')\n\n        if n_chunks == 1:\n            # For small vocabs <= 65536 like Llama, Mistral\n            BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n            logsumexp = torch.empty(n_rows, 
dtype=torch.float32, device='cuda')\n\n            _cross_entropy_forward[(n_rows,)](\n                logits,\n                logits.stride(0),\n                losses,\n                logsumexp,\n                labels,\n                VOCAB_SIZE=vocab_size,\n                BLOCK_SIZE=BLOCK_SIZE,\n                DO_LOGIT_SCALING=(logit_scale != 1.0),\n                LOGIT_SCALE=logit_scale,\n                num_warps=num_warps,\n            )\n        else:\n            # For large vocabs > 65536 like Gemma 256K\n            logsumexp = torch.empty(\n                (\n                    n_rows,\n                    n_chunks,\n                ),\n                dtype=torch.float32,\n                device='cuda',\n            )\n\n            _chunked_cross_entropy_forward[\n                (\n                    n_rows,\n                    n_chunks,\n                )\n            ](\n                logits,\n                logits.stride(0),\n                losses,\n                logsumexp,\n                labels,\n                VOCAB_SIZE=vocab_size,\n                N_CHUNKS=n_chunks,\n                BLOCK_SIZE=MAX_FUSED_SIZE,\n                DO_LOGIT_SCALING=(logit_scale != 1.0),\n                LOGIT_SCALE=logit_scale,\n                num_warps=32 if device_warp_size() < 64 else 16,\n            )\n            # logsumexp(chunked_logsumexp) - x\n            # Do the -x separately\n            logsumexp = torch.logsumexp(logsumexp, dim=1)  # Row sum\n            losses += logsumexp\n            losses.masked_fill_(labels == -100, 0)  # Don't forget to mask padding out!\n        pass\n\n        ctx.save_for_backward(logits, logsumexp, labels)\n        ctx.logit_scale = logit_scale\n        return losses\n\n    pass\n\n    @staticmethod\n    def backward(ctx, dlosses):\n        logits, logsumexp, labels = ctx.saved_tensors\n        n_rows, vocab_size = logits.shape\n\n        BLOCK_SIZE = 4096\n        div, mod = divmod(vocab_size, 
BLOCK_SIZE)\n        n_blocks = div + (mod != 0)\n\n        _cross_entropy_backward[\n            (\n                n_rows,\n                n_blocks,\n            )\n        ](\n            logits,\n            logits.stride(0),\n            dlosses,\n            dlosses.stride(0),\n            logsumexp,\n            labels,\n            VOCAB_SIZE=vocab_size,\n            BLOCK_SIZE=BLOCK_SIZE,\n            DO_LOGIT_SCALING=(ctx.logit_scale != 1.0),\n            LOGIT_SCALE=ctx.logit_scale,\n            num_warps=8,\n        )\n        return (\n            logits,\n            None,\n            None,\n        )\n\n    pass\n\n\npass\n\n\ndef fast_cross_entropy_loss(logits, labels, logit_scale=1.0):\n    \"\"\"\n    Arguments:\n        logits: (batch, seq_len, vocab_size)\n        labels: (batch, seq_len,)\n    Returns:\n        losses: float\n    \"\"\"\n    batch, seq_len, d = logits.shape\n    assert labels.shape == (batch, seq_len)\n\n    loss = Fast_CrossEntropyLoss.apply(\n        logits.view(batch * seq_len, d),\n        labels.view(-1),\n        logit_scale,\n    )\n    n_items = torch.count_nonzero(labels != -100)\n    return loss.sum() / n_items\n\n\npass\n"
  },
  {
    "path": "kernels/utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ctypes\n\nimport bitsandbytes as bnb\nimport torch\nimport triton\n\n\nMAX_FUSED_SIZE = 65536\nnext_power_of_2 = triton.next_power_of_2\n\n\ndef device_warp_size():\n    if torch.cuda.is_available() and torch.version.hip and torch.cuda.get_device_capability(0)[0] < 10:\n        return 64\n    else:\n        return 32\n\n\ndef calculate_settings(n):\n    BLOCK_SIZE = next_power_of_2(n)\n    if BLOCK_SIZE > MAX_FUSED_SIZE:\n        raise RuntimeError(\n            f'Cannot launch Triton kernel since n = {n} exceeds the maximum CUDA blocksize = {MAX_FUSED_SIZE}.'\n        )\n    warp_scalar = 32 / float(device_warp_size())\n    num_warps = 4\n    if BLOCK_SIZE >= 32768:\n        num_warps = 32 * warp_scalar\n    elif BLOCK_SIZE >= 8192:\n        num_warps = 16 * warp_scalar\n    elif BLOCK_SIZE >= 2048:\n        num_warps = 8 * warp_scalar\n    return BLOCK_SIZE, int(num_warps)\n\n\npass\n\n\nget_ptr = bnb.functional.get_ptr\n\ncdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32\ncdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4\ncdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4\ncgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16\ncgemm_4bit_inference_naive_bf16 = 
bnb.functional.lib.cgemm_4bit_inference_naive_bf16\n\n\ndef QUANT_STATE(W):\n    return getattr(W, 'quant_state', None)\n\n\npass\n\n\ndef get_lora_parameters(proj):\n    # For DPO or disabled adapters\n    base_layer = proj.base_layer if hasattr(proj, 'base_layer') else proj\n    W = base_layer.weight\n\n    if not hasattr(proj, 'disable_adapters') or proj.disable_adapters or proj.merged:\n        return W, QUANT_STATE(W), None, None, None\n    pass\n\n    active_adapter = proj.active_adapters[0] if hasattr(proj, 'active_adapters') else proj.active_adapter\n    A = proj.lora_A[active_adapter].weight\n    B = proj.lora_B[active_adapter].weight\n    s = proj.scaling[active_adapter]\n    return W, QUANT_STATE(W), A, B, s\n\n\npass\n\n\ndef fast_dequantize(W, quant_state=None, out=None):\n    if quant_state is None:\n        return W\n    if type(quant_state) is not list:\n        # New quant_state as a class\n        # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n        absmax = quant_state.absmax\n        shape = quant_state.shape\n        dtype = quant_state.dtype\n        blocksize = quant_state.blocksize\n        offset = quant_state.offset\n        state2 = quant_state.state2\n        absmax2 = state2.absmax\n        code2 = state2.code\n        blocksize2 = state2.blocksize\n    else:\n        # Old quant_state as a list of lists\n        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state\n        offset, state2 = compressed_stats\n        absmax2, code2, blocksize2, _, _, _, _ = state2\n    pass\n\n    # Create weight matrix\n    if out is None:\n        out = torch.empty(shape, dtype=dtype, device='cuda')\n    else:\n        assert out.shape == shape\n        assert out.dtype == dtype\n\n    # NF4 dequantization of statistics\n    n_elements_absmax = absmax.numel()\n    out_absmax = torch.empty(n_elements_absmax, dtype=torch.float32, device='cuda')\n\n    # Do dequantization\n    ptr_out_absmax = get_ptr(out_absmax)\n    
cdequantize_blockwise_fp32(\n        get_ptr(code2),\n        get_ptr(absmax),\n        get_ptr(absmax2),\n        ptr_out_absmax,\n        ctypes.c_int(blocksize2),\n        ctypes.c_int(n_elements_absmax),\n    )\n    out_absmax += offset\n\n    fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else cdequantize_blockwise_bf16_nf4\n    fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), ctypes.c_int(blocksize), ctypes.c_int(out.numel()))\n\n    # Careful returning transposed data\n    is_transposed = True if W.shape[0] == 1 else False\n    return out.t() if is_transposed else out\n\n\npass\n\n\ndef fast_gemv(X, W, quant_state, out=None):\n    if quant_state is None:\n        return torch.matmul(X, W, out=out)\n    # For fast X @ W where seq_len == 1\n    # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469\n    _, q_len, hd = X.shape\n    # assert(q_len == 1)\n\n    if type(quant_state) is not list:\n        # https://github.com/TimDettmers/bitsandbytes/pull/763/files\n        absmax = quant_state.absmax\n        shape = quant_state.shape\n        dtype = quant_state.dtype\n        blocksize = quant_state.blocksize\n        stats = quant_state.code\n        offset = quant_state.offset\n        state2 = quant_state.state2\n        absmax2 = state2.absmax\n        code2 = state2.code\n        blocksize2 = state2.blocksize\n    else:\n        absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state\n        offset, state2 = compressed_stats\n        absmax2, code2, blocksize2, _, _, _, _ = state2\n    pass\n    # assert(dtype == X.dtype)\n    bout = shape[0]\n\n    if out is None:\n        out = torch.empty(\n            (\n                1,\n                1,\n                bout,\n            ),\n            dtype=dtype,\n            device='cuda',\n        )\n    # else:\n    #     assert(out.shape == (1, 1, bout,))\n    # pass\n\n    n = 1\n    m = shape[0]\n    k = 
shape[1]\n    lda = shape[0]\n    ldc = shape[0]\n    ldb = (hd + 1) // 2\n    m = ctypes.c_int32(m)\n    n = ctypes.c_int32(n)\n    k = ctypes.c_int32(k)\n    lda = ctypes.c_int32(lda)\n    ldb = ctypes.c_int32(ldb)\n    ldc = ctypes.c_int32(ldc)\n\n    df = torch.empty(absmax.shape, dtype=torch.float32, device='cuda')\n    cdequantize_blockwise_fp32(\n        get_ptr(code2),\n        get_ptr(absmax),\n        get_ptr(absmax2),\n        get_ptr(df),\n        ctypes.c_int(blocksize2),\n        ctypes.c_int(df.numel()),\n    )\n    df += offset\n    absmax = df\n\n    fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else cgemm_4bit_inference_naive_bf16\n\n    blocksize = ctypes.c_int32(blocksize)\n    fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize)\n\n    return out\n\n\npass\n\n\ndef fast_linear_forward(proj, X, temp_lora=None, out=None):\n    W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)\n    bsz, q_len, in_dim = X.shape\n    if q_len != 1:\n        return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)\n\n    if W_quant is None:\n        out = torch.matmul(X, W.t(), out=out)\n    elif bsz == 1 and q_len == 1:\n        out = fast_gemv(X, W, W_quant, out=out)\n    else:\n        W = fast_dequantize(W.t(), W_quant)\n        out = torch.matmul(X, W, out=out)\n    pass\n\n    # Add in LoRA weights\n    if lora_A is not None:\n        out_dim = out.shape[2]\n        dtype = X.dtype\n\n        if not hasattr(lora_A, '_fast_lora'):\n            lora_A._fast_lora = lora_A.to(dtype)\n            lora_B._fast_lora = lora_B.to(dtype)\n        pass\n\n        if bsz == 1:\n            out = out.view(out_dim)\n            temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out=temp_lora)\n            out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S)\n        else:\n            out = out.view(bsz, out_dim)\n            temp_lora = torch.mm(X.view(bsz, in_dim), 
lora_A._fast_lora.t(), out=temp_lora)\n            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S)\n        pass\n        out = out.view(bsz, 1, out_dim)\n    pass\n\n    return out\n\n\npass\n\n\ndef matmul_lora(X, W, W_quant, A, B, s, out=None):\n    dtype = X.dtype\n    W = fast_dequantize(W.t(), W_quant)\n\n    if X.dim() == 3:\n        batch, seq_len, d = X.shape\n        X = X.view(-1, X.shape[-1])\n        reshape = True\n    else:\n        reshape = False\n    pass\n\n    out = torch.matmul(X, W, out=out)\n    if W_quant is not None:\n        del W\n\n    if A is not None:\n        # LoRA is enabled\n        A, B = A.t(), B.t()\n        out += (X @ A.to(dtype)) @ (s * B.to(dtype))\n    pass\n\n    return out.view(batch, seq_len, -1) if reshape else out\n\n\npass\n"
  },
  {
    "path": "models/layers.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\nimport transformers\nfrom deepspeed.runtime.pipe import module as ds_pipe_module\nfrom torch import nn\n\nfrom kernels.cross_entropy_loss import Fast_CrossEntropyLoss\n\n\ndef move_data_to_device(module, device):\n    non_blocking = (device != 'cpu')\n    # handle lora\n    if hasattr(module, 'base_layer'):\n        module = module.base_layer\n    # handle HQQ\n    if hasattr(module, 'W_q'):\n        orig_data = module.W_q.data\n        module.W_q.data = orig_data.to(device, non_blocking=non_blocking)\n    else:\n        orig_data = module.weight.data\n        module.weight.data = orig_data.to(device, non_blocking=non_blocking)\n    return orig_data\n\n\ndef set_data(module, data):\n    # handle lora\n    if hasattr(module, 'base_layer'):\n        module = module.base_layer\n    # handle HQQ\n    if hasattr(module, 'W_q'):\n        module.W_q.data = data\n    else:\n        module.weight.data = data\n\n\ndef move_experts_to_device(experts, device, num_experts_to_offload):\n    orig_data = []\n    for i in range(num_experts_to_offload):\n        orig_w1 = move_data_to_device(experts[i].w1, device)\n        orig_w2 = move_data_to_device(experts[i].w2, device)\n        orig_w3 = move_data_to_device(experts[i].w3, device)\n        orig_data.append((orig_w1, orig_w2, orig_w3))\n    return orig_data\n\n\ndef set_experts_data(experts, orig_data):\n    for i, (orig_w1, orig_w2, orig_w3) in enumerate(orig_data):\n        set_data(experts[i].w1, orig_w1)\n        set_data(experts[i].w2, orig_w2)\n        set_data(experts[i].w3, orig_w3)\n\n\ndef entropy_fn(logits):\n    result = []\n    # There is a very wide range of chuck sizes that cause no increase in memory reported by\n    # nvidia-smi (Torch re-using blocks of memory?). If you try to compute it as one tensor,\n    # memory usage is huge. 
Chuck size of 128 seems good enough for now.\n    for logits_chuck in torch.split(logits, 128):\n        result.append(torch.distributions.Categorical(logits=logits_chuck).entropy())\n    return torch.cat(result).float()\n\n\ndef top_k_accuracy(logits, labels, k_list, ignore_index=-100):\n    keep = labels != ignore_index\n    labels = labels[keep].view(-1, 1)\n    max_k = max(k_list)\n    _, top_k_predictions = torch.topk(logits, max_k, dim=-1, sorted=True)\n    top_k_predictions = top_k_predictions[keep]\n    accuracies = []\n    for k in k_list:\n        accuracies.append(torch.any(top_k_predictions[:, :k] == labels, dim=-1).to(torch.float32).mean())\n    return accuracies\n\n\nclass LayerSpec(ds_pipe_module.LayerSpec):\n    def __init__(self, typename, *module_args, **module_kwargs):\n        super().__init__(typename, *module_args, **module_kwargs)\n\n    def build(self):\n        self.module_kwargs.pop('_estimated_size', None)\n        return self.typename(*self.module_args, **self.module_kwargs)\n\n    @property\n    def estimated_size(self):\n        return self.module_kwargs.get('_estimated_size', 1)\n\n\n# TODO: consider using Liger-Kernel fused loss implementations. The inputs are already set up to support this.\n# Would save VRAM, but some metrics could no longer be computed (e.g. entropy, accuracies).\nclass OutputLayer(nn.Module):\n    def __init__(\n        self,\n        pipeline_model,\n        loader_util,\n        lm_head,\n        logit_scale=1.0,\n        loss_type='cross_entropy_loss',\n        focal_loss_gamma=0,\n        tie_weights=None,\n        logit_softcapping=None,\n    ):\n        super().__init__()\n        # Assign list to prevent registering the nn.Module\n        self.pipeline_model = [pipeline_model]\n        # Unlike the other wrapper classes, this is called lm_head and not orig. 
Because this is directly a\n        # nn.Linear layer, it needs to keep the same attribute name so quantization knows not to quantize it.\n        self.lm_head = lm_head\n        self.logit_scale = logit_scale\n        self.loss_type = loss_type.lower()\n        self.focal_loss_gamma = focal_loss_gamma\n        if tie_weights:\n            self.lm_head.weight.original_name = tie_weights\n        self.logit_softcapping = logit_softcapping\n        loader_util.load_state_dict_into_module(self)\n\n        if self.loss_type == 'cross_entropy_loss' and self.focal_loss_gamma != 0:\n            raise ValueError(\"focal_loss_gamma can't be used with 'cross_entropy_loss' function\")\n\n    def forward(self, inputs):\n        hidden_states, labels = inputs\n        if self.pipeline_model[0].sampling_mode:\n            # When sampling only compute the last logits.\n            hidden_states = hidden_states[:, -1:, :]\n        labels = labels.to(hidden_states.device)\n        if self.logit_scale == 1.0:\n            logits = self.lm_head(hidden_states)\n        else:\n            logits = self.lm_head(self.logit_scale * hidden_states)\n        if self.logit_softcapping is not None and self.logit_softcapping > 0:\n            logits = logits / self.logit_softcapping\n            logits = torch.tanh(logits)\n            logits = logits * self.logit_softcapping\n\n        if self.pipeline_model[0].sampling_mode:\n            return logits\n\n        extra_ignored_labels = torch.full((labels.shape[0], 1), -100, device=logits.device)\n        labels = torch.hstack((labels[..., 1:], extra_ignored_labels))\n        # Flatten the tokens\n        vocab_size = logits.size(-1)\n        flat_logits = logits.view(-1, vocab_size)\n        flat_labels = labels.view(-1)\n        flat_loss_mask = flat_labels >= 0\n\n        cross_entropy_loss = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels)\n\n        loss = None\n        if self.loss_type == 'cross_entropy_loss':\n            
cross_entropy_loss = cross_entropy_loss[flat_loss_mask]\n            loss_unreduced = cross_entropy_loss\n        elif self.loss_type == 'focal_loss':\n            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]\n            # See https://arxiv.org/abs/1708.02002 (Section 3)\n            p = torch.exp(-cross_entropy_loss)\n            loss_unreduced = (1 - p) ** self.focal_loss_gamma * cross_entropy_loss\n        elif self.loss_type == 'focal_loss_star':\n            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]\n            # See https://arxiv.org/abs/1708.02002 (Appendix A/B)\n            # NOTE: The use of Beta makes no sense for the multinomial case as it's invariant to translation\n            loss_unreduced = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels, self.focal_loss_gamma)\n            loss_unreduced = loss_unreduced[flat_loss_mask]\n            loss_unreduced = loss_unreduced / self.focal_loss_gamma\n        elif self.loss_type == 'inverse_focal_loss':\n            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]\n            # See \"Rethinking Calibration of Deep Neural Networks: Do Not Be Afraid of Overconfidence\" (Section 5.2)\n            # NOTE: The alternative of p^gamma (instead of (1+p)^gamma) might be useful for gradient ascent...\n            p = torch.exp(-cross_entropy_loss)\n            loss_unreduced = (1 + p) ** self.focal_loss_gamma * cross_entropy_loss\n        elif self.loss_type == 'exponentiated_cross_entropy_loss':\n            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]\n            # See \"Gradient as a Foundation for Building a Loss Function\" (Section III.B)\n            # NOTE: This is a generalisation of their \"Quadratic Cross-Entropy\" loss (QCE: gamma=2, CE: gamma=1, etc).\n            loss_unreduced = cross_entropy_loss**self.focal_loss_gamma / self.focal_loss_gamma\n        elif self.loss_type == 'dpo':\n            rl_config = self.pipeline_model[0].train_config['rl']\n        
    cross_entropy_loss = cross_entropy_loss.view_as(labels)  # unflatten\n            loss_mask = labels >= 0\n            logps = -(cross_entropy_loss * loss_mask).sum(-1)\n            half = cross_entropy_loss.size(0) // 2\n            chosen_logps = logps[:half]\n            rejected_logps = logps[half:]\n\n            if self.pipeline_model[0].dpo_reference_mode:\n                self.reference_chosen_logps = chosen_logps.detach()\n                self.reference_rejected_logps = rejected_logps.detach()\n                return torch.tensor(0.0, device=logits.device)\n\n            # log the language modeling loss metrics on the chosen completion\n            cross_entropy_loss = cross_entropy_loss[:half].flatten()[loss_mask[:half].flatten()]\n            hidden_states = hidden_states[:half]\n            loss_unreduced = cross_entropy_loss\n            flat_logits = logits[:half].view(-1, vocab_size)\n            flat_labels = labels[:half].view(-1)\n            flat_loss_mask = flat_labels >= 0\n\n            policy_chosen_logps = chosen_logps\n            policy_rejected_logps = rejected_logps\n            pi_logratios = policy_chosen_logps - policy_rejected_logps\n            ref_logratios = self.reference_chosen_logps - self.reference_rejected_logps\n            del self.reference_chosen_logps\n            del self.reference_rejected_logps\n            dpo_logits = pi_logratios - ref_logratios\n            loss = -F.logsigmoid(rl_config['dpo_beta'] * dpo_logits).mean()\n        else:\n            raise NotImplementedError(self.loss_type)\n\n        with torch.no_grad():\n            log_vocab_size = math.log(logits.size(-1))\n            entropy = entropy_fn(flat_logits)[flat_loss_mask]\n            # Compute normalised entropy so we can compare between models with different vocab sizes\n            normalised_entropy = entropy / log_vocab_size\n            # Compute the (negative) log-likelihood using the original *UNADJUSTED* Cross-Entropy loss.\n           
 log_likelihood = cross_entropy_loss.mean()\n            # Compute McFadden's Pseudo-R² metric using log(vocab_size) as the null log-likelihood.\n            mcfaddens_pseudo_r2 = 1 - (log_likelihood / log_vocab_size)\n            accuracies = top_k_accuracy(flat_logits, flat_labels, k_list=[1, 5, 20])\n            # Compute the norms of the (pre-logit-scaled) hidden states\n            hidden_state_norms = torch.norm(hidden_states.float(), dim=-1)\n            hidden_state_norms = hidden_state_norms.view(-1)[flat_loss_mask]\n        if loss is None:\n            # Normal language modeling loss types (e.g. not DPO)\n            loss = loss_unreduced.mean()\n        loss_unreduced = loss_unreduced.detach()\n        return (\n            loss,\n            loss_unreduced,\n            hidden_state_norms,\n            entropy,\n            normalised_entropy,\n            log_likelihood,\n            mcfaddens_pseudo_r2,\n            *accuracies,\n        )\n\n\ndef load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float:\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        stacked_gate_logits = torch.stack([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(stacked_gate_logits, dim=-1)  # [num_layers, num_tokens, num_experts]\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)  # [num_layers, num_tokens, top_k]\n    expert_mask = torch.nn.functional.one_hot(\n        selected_experts, num_experts\n    )  # [num_layers, num_tokens, top_k, num_experts]\n    # For a given token, determine if it was routed to a given expert. 
Think of this as a collection of top_k-hot vectors.\n    expert_mask = torch.max(expert_mask, dim=-2).values.float()  # [num_layers, num_tokens, num_experts]\n    tokens_per_layer_and_expert = torch.mean(expert_mask, dim=-2)  # [num_layers, num_experts]\n    router_prob_per_layer_and_expert = torch.mean(routing_weights, dim=-2)  # [num_layers, num_experts]\n    return torch.mean(tokens_per_layer_and_expert * router_prob_per_layer_and_expert) * num_experts**2\n\n\nclass MixtralOutputLayer(OutputLayer):\n    def __init__(\n        self,\n        pipeline_model,\n        loader_util,\n        lm_head,\n        load_balancing_loss_coef,\n        num_experts,\n        num_experts_per_tok,\n        **kwargs,\n    ):\n        super().__init__(pipeline_model, loader_util, lm_head, **kwargs)\n        self.load_balancing_loss_coef = load_balancing_loss_coef\n        self.num_experts = num_experts\n        self.num_experts_per_tok = num_experts_per_tok\n\n    def forward(self, inputs):\n        hidden_states, labels, *router_logits = inputs\n        router_logits = tuple(router_logits)\n        outputs = super().forward((hidden_states, labels))\n        if self.pipeline_model[0].sampling_mode:\n            return outputs\n        if self.load_balancing_loss_coef is not None:\n            aux_loss = transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func(\n                router_logits, self.num_experts, self.num_experts_per_tok\n            )\n            alternate_aux_loss = load_balancing_loss_func(router_logits, self.num_experts, self.num_experts_per_tok)\n            loss = outputs[0]\n            loss += self.load_balancing_loss_coef * aux_loss\n            outputs = (loss, *outputs[1:], aux_loss, alternate_aux_loss)\n        return outputs\n\n\nclass InputLayer(nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self._model = [model]\n        self.embed_tokens = model.model.embed_tokens\n        self.rotary_emb = 
model.model.rotary_emb\n        self.embedding_on_cpu = not self.model.train_config['full_fine_tune']\n        self.model.loader_util.load_state_dict_into_module(self)\n\n    @property\n    def model(self):\n        return self._model[0]\n\n    def forward(self, inputs):\n        past_key_values = None\n        cache_position = None\n        use_cache = self.model.sampling_mode\n\n        input_ids, attention_mask, labels = inputs[:3]\n        device = input_ids.device\n        if self.embedding_on_cpu:\n            self.embed_tokens.to('cpu')\n            input_ids = input_ids.to('cpu')\n\n        inputs_embeds = self.embed_tokens(input_ids).to(device)\n        if use_cache:\n            past_key_values = self.model.cache\n\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        cache_position = torch.arange(\n            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n        )\n        position_ids = cache_position.unsqueeze(0)\n\n        attention_mask = self.model.model._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, None\n        )\n        if attention_mask is None:\n            # attention_mask can end up being None, which means use full causal attention. But with pipeline parallelism,\n            # we can only pass tensors between layers. So make it an empty tensor, which will later be detected by the layers\n            # and converted back to None. Note: this only works now because dynamic_shape=True in the pipeline engine.\n            attention_mask = torch.tensor([], device=device)\n        # Work around a very strange Deepspeed bug. The combination of PipelineModule dynamic_shape=True, attention_mask being\n        # an integer dtype, and pipeline_stages>2 causes training (but not eval) to hang. 
So cast to float, and cast back to int\n        # in the layer.\n        if torch.is_tensor(attention_mask) and not torch.is_floating_point(attention_mask):\n            attention_mask = attention_mask.to(inputs_embeds.dtype)\n\n        hidden_states = inputs_embeds\n        if self.model.model.config.model_type == 'gemma2':\n            normalizer = torch.tensor(self.model.model.config.hidden_size**0.5, dtype=hidden_states.dtype)\n            hidden_states = hidden_states * normalizer\n\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n\n        output = hidden_states, attention_mask, cos, sin, cache_position, labels\n        # Deepspeed requirement. Float tensors must require grad.\n        for tensor in output:\n            if torch.is_floating_point(tensor):\n                tensor.requires_grad_(True)\n        return output\n\n\nclass LlamaRMSNormPipe(nn.Module):\n    def __init__(self, loader_util, orig):\n        super().__init__()\n        self.orig = orig\n        loader_util.load_state_dict_into_module(self)\n\n    def forward(self, inputs):\n        hidden_states, _, _, _, _, labels, *router_logits = inputs\n        return self.orig(hidden_states), labels, *router_logits\n\n\nclass LlamaDecoderLayerPipe(nn.Module):\n    def __init__(self, pipeline_model, loader_util, orig):\n        super().__init__()\n        self.pipeline_model = [pipeline_model]\n        self.orig = orig\n        self.mlp_offloaded_to_cpu = False\n        self.attn_implementation = pipeline_model.config._attn_implementation\n        loader_util.load_state_dict_into_module(self)\n\n    # A note on MLP offloading:\n    # We take advantage of how activation checkpointing works with reentrant checkpointing functions.\n    # During the forward pass, if gradients are disabled (eval or first forward pass of activation checkpointing)\n    # we offload the weights back to CPU at the end of the function. 
If gradients are enabled (second forward pass\n    # of activation checkpointing) we leave the weights on GPU, and use a backward hook to offload to CPU after the\n    # backward pass of this function is completed. This way the weights stay on the GPU for the backward pass.\n    def forward(self, inputs):\n        def move_mlp_to_cpu_hook(grad):\n            self.move_mlp_to_cpu()\n            return None\n\n        hidden_states, attention_mask, cos, sin, cache_position, labels = inputs\n        if self.mlp_offloaded_to_cpu:\n            if hidden_states.requires_grad:\n                hidden_states.register_hook(move_mlp_to_cpu_hook)\n            self.move_mlp_to_device(hidden_states.device)\n        kwargs = {}\n        if attention_mask.numel() == 0:\n            # We can't pass None between pipeline layers, so this signals that attention_mask should be None.\n            kwargs['attention_mask'] = None\n        else:\n            # We have to pass attention mask between layers as float dtype, because in certain cases training hangs otherwise. 
So\n            # now cast it back to int64 if we are using flash_attention_2.\n            kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask\n        kwargs['position_embeddings'] = (cos, sin)\n        if self.pipeline_model[0].sampling_mode:\n            kwargs['use_cache'] = True\n            kwargs['past_key_value'] = self.pipeline_model[0].cache\n        kwargs['cache_position'] = cache_position\n        result = (self.orig(hidden_states, **kwargs)[0], attention_mask, cos, sin, cache_position, labels)\n        if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():\n            self.move_mlp_to_cpu()\n        return result\n\n    def move_mlp_to_cpu(self):\n        # If it's already been moved to CPU once, just set the data to avoid a transfer.\n        if self.mlp_offloaded_to_cpu:\n            set_data(self.orig.mlp.up_proj, self.cpu_up_proj)\n            set_data(self.orig.mlp.down_proj, self.cpu_down_proj)\n            set_data(self.orig.mlp.gate_proj, self.cpu_gate_proj)\n            return\n\n        move_data_to_device(self.orig.mlp.up_proj, 'cpu')\n        move_data_to_device(self.orig.mlp.down_proj, 'cpu')\n        move_data_to_device(self.orig.mlp.gate_proj, 'cpu')\n        self.mlp_offloaded_to_cpu = True\n\n    def move_mlp_to_device(self, device):\n        self.cpu_up_proj = move_data_to_device(self.orig.mlp.up_proj, device)\n        self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device)\n        self.cpu_gate_proj = move_data_to_device(self.orig.mlp.gate_proj, device)\n\n\nclass Phi3DecoderLayerPipe(LlamaDecoderLayerPipe):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def move_mlp_to_cpu(self):\n        if self.mlp_offloaded_to_cpu:\n            set_data(self.orig.mlp.gate_up_proj, self.cpu_gate_up_proj)\n            set_data(self.orig.mlp.down_proj, self.cpu_down_proj)\n            return\n\n       
 move_data_to_device(self.orig.mlp.gate_up_proj, 'cpu')\n        move_data_to_device(self.orig.mlp.down_proj, 'cpu')\n        self.mlp_offloaded_to_cpu = True\n\n    def move_mlp_to_device(self, device):\n        self.cpu_gate_up_proj = move_data_to_device(self.orig.mlp.gate_up_proj, device)\n        self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device)\n\n\nclass MixtralDecoderLayerPipe(LlamaDecoderLayerPipe):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.num_experts_to_offload = self.pipeline_model[0].num_experts_to_offload\n\n    def forward(self, inputs):\n        def move_mlp_to_cpu_hook(grad):\n            self.move_mlp_to_cpu()\n            return None\n\n        hidden_states, attention_mask, cos, sin, cache_position, labels, *input_router_logits = inputs\n        if self.mlp_offloaded_to_cpu:\n            if hidden_states.requires_grad:\n                hidden_states.register_hook(move_mlp_to_cpu_hook)\n            self.move_mlp_to_device(hidden_states.device)\n        kwargs = {}\n        if attention_mask.numel() == 0:\n            # We can't pass None between pipeline layers, so this signals that attention_mask should be None.\n            kwargs['attention_mask'] = None\n        else:\n            kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask\n        kwargs['position_embeddings'] = (cos, sin)\n        if self.pipeline_model[0].sampling_mode:\n            kwargs['use_cache'] = True\n            kwargs['past_key_value'] = self.pipeline_model[0].cache\n        hidden_states, router_logits = self.orig(hidden_states, output_router_logits=True, **kwargs)\n        # TODO: fix unsloth gradient checkpointing when we return router logits\n        # router_logits = router_logits.to(torch.float32)\n        # router_logits = input_router_logits + (router_logits,)\n        # result = (hidden_states, 
attention_mask, cos, sin, labels, *router_logits)\n        result = (hidden_states, attention_mask, cos, sin, cache_position, labels)\n        if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():\n            self.move_mlp_to_cpu()\n        return result\n\n    def move_mlp_to_cpu(self):\n        if self.mlp_offloaded_to_cpu:\n            set_experts_data(self.orig.block_sparse_moe.experts, self.orig_data)\n            return\n\n        move_experts_to_device(self.orig.block_sparse_moe.experts, 'cpu', self.num_experts_to_offload)\n        self.mlp_offloaded_to_cpu = True\n\n    def move_mlp_to_device(self, device):\n        self.orig_data = move_experts_to_device(\n            self.orig.block_sparse_moe.experts, device, self.num_experts_to_offload\n        )\n\n\nclass Gemma3InputLayer(nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self._model = [model]\n        self.embed_tokens = model.model.embed_tokens\n        self.rotary_emb = model.model.rotary_emb\n        self.rotary_emb_local = model.model.rotary_emb_local\n        self.embedding_on_cpu = not self.model.train_config['full_fine_tune']\n        self.model.loader_util.load_state_dict_into_module(self)\n\n    @property\n    def model(self):\n        return self._model[0]\n\n    def forward(self, inputs):\n        past_key_values = None\n        cache_position = None\n        use_cache = self.model.sampling_mode\n\n        input_ids, attention_mask, labels = inputs[:3]\n        device = input_ids.device\n        if self.embedding_on_cpu:\n            self.embed_tokens.to('cpu')\n            input_ids = input_ids.to('cpu')\n\n        inputs_embeds = self.embed_tokens(input_ids).to(device)\n        if use_cache:\n            past_key_values = self.model.cache\n\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        cache_position = torch.arange(\n            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], 
device=inputs_embeds.device\n        )\n        position_ids = cache_position.unsqueeze(0)\n\n        attention_mask = self.model.model._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, None\n        )\n        if attention_mask is None:\n            # attention_mask can end up being None, which means use full causal attention. But with pipeline parallelism,\n            # we can only pass tensors between layers. So make it an empty tensor, which will later be detected by the layers\n            # and converted back to None. Note: this only works now because dynamic_shape=True in the pipeline engine.\n            attention_mask = torch.tensor([], device=device)\n        # Work around a very strange Deepspeed bug. The combination of PipelineModule dynamic_shape=True, attention_mask being\n        # an integer dtype, and pipeline_stages>2 causes training (but not eval) to hang. So cast to float, and cast back to int\n        # in the layer.\n        if torch.is_tensor(attention_mask) and not torch.is_floating_point(attention_mask):\n            attention_mask = attention_mask.to(inputs_embeds.dtype)\n\n        hidden_states = inputs_embeds\n        if self.model.model.config.model_type == 'gemma2':\n            normalizer = torch.tensor(self.model.model.config.hidden_size**0.5, dtype=hidden_states.dtype)\n            hidden_states = hidden_states * normalizer\n\n        position_embeddings_global_cos, position_embeddings_global_sin = self.rotary_emb(hidden_states, position_ids)\n        position_embeddings_local_cos, position_embeddings_local_sin = self.rotary_emb_local(hidden_states, position_ids)\n\n        output = hidden_states, attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels\n        # Deepspeed requirement. 
Float tensors must require grad.\n        for tensor in output:\n            if torch.is_floating_point(tensor):\n                tensor.requires_grad_(True)\n        return output\n\n\nclass Gemma3DecoderLayerPipe(nn.Module):\n    def __init__(self, pipeline_model, loader_util, orig):\n        super().__init__()\n        self.pipeline_model = [pipeline_model]\n        self.orig = orig\n        self.mlp_offloaded_to_cpu = False\n        self.attn_implementation = pipeline_model.config._attn_implementation\n        loader_util.load_state_dict_into_module(self)\n\n    def forward(self, inputs):\n        def move_mlp_to_cpu_hook(grad):\n            self.move_mlp_to_cpu()\n            return None\n\n        hidden_states, attention_mask, position_embeddings_global_cos, position_embeddings_global_sin, position_embeddings_local_cos, position_embeddings_local_sin, cache_position, labels = inputs\n        if self.mlp_offloaded_to_cpu:\n            if hidden_states.requires_grad:\n                hidden_states.register_hook(move_mlp_to_cpu_hook)\n            self.move_mlp_to_device(hidden_states.device)\n        kwargs = {}\n        if attention_mask.numel() == 0:\n            # We can't pass None between pipeline layers, so this signals that attention_mask should be None.\n            kwargs['attention_mask'] = None\n        else:\n            kwargs['attention_mask'] = attention_mask.to(torch.int64) if self.attn_implementation == 'flash_attention_2' else attention_mask\n        kwargs['position_embeddings_global'] = (position_embeddings_global_cos, position_embeddings_global_sin)\n        kwargs['position_embeddings_local'] = (position_embeddings_local_cos, position_embeddings_local_sin)\n        kwargs['cache_position'] = cache_position\n        if self.pipeline_model[0].sampling_mode:\n            kwargs['use_cache'] = True\n            kwargs['past_key_value'] = self.pipeline_model[0].cache\n        result = (\n            self.orig(hidden_states, **kwargs)[0],\n       
     attention_mask,\n            position_embeddings_global_cos,\n            position_embeddings_global_sin,\n            position_embeddings_local_cos,\n            position_embeddings_local_sin,\n            cache_position,\n            labels\n        )\n        if self.mlp_offloaded_to_cpu and not torch.is_grad_enabled():\n            self.move_mlp_to_cpu()\n        return result\n\n    def move_mlp_to_cpu(self):\n        # If it's already been moved to CPU once, just set the data to avoid a transfer.\n        if self.mlp_offloaded_to_cpu:\n            set_data(self.orig.mlp.up_proj, self.cpu_up_proj)\n            set_data(self.orig.mlp.down_proj, self.cpu_down_proj)\n            set_data(self.orig.mlp.gate_proj, self.cpu_gate_proj)\n            return\n\n        move_data_to_device(self.orig.mlp.up_proj, 'cpu')\n        move_data_to_device(self.orig.mlp.down_proj, 'cpu')\n        move_data_to_device(self.orig.mlp.gate_proj, 'cpu')\n        self.mlp_offloaded_to_cpu = True\n\n    def move_mlp_to_device(self, device):\n        self.cpu_up_proj = move_data_to_device(self.orig.mlp.up_proj, device)\n        self.cpu_down_proj = move_data_to_device(self.orig.mlp.down_proj, device)\n        self.cpu_gate_proj = move_data_to_device(self.orig.mlp.gate_proj, device)\n\n\nclass Gemma3RMSNormPipe(nn.Module):\n    def __init__(self, loader_util, orig):\n        super().__init__()\n        self.orig = orig\n        loader_util.load_state_dict_into_module(self)\n\n    def forward(self, inputs):\n        hidden_states, *_, labels = inputs\n        return self.orig(hidden_states), labels\n"
  },
  {
    "path": "models/models.py",
    "content": "import accelerate\nimport torch\nimport transformers\n\nfrom models.layers import (\n    InputLayer,\n    LayerSpec,\n    LlamaDecoderLayerPipe,\n    LlamaRMSNormPipe,\n    MixtralDecoderLayerPipe,\n    MixtralOutputLayer,\n    OutputLayer,\n    Phi3DecoderLayerPipe,\n    Gemma3InputLayer,\n    Gemma3DecoderLayerPipe,\n    Gemma3RMSNormPipe,\n)\nfrom models.pipeline_model import PipelineModel\nfrom utils.utils import DTYPE_MAP\n\nDEFAULT_ATTN_IMPLEMENTATION = 'flash_attention_2'\n\n\n# A little bit of inheritance and MRO trickery since LlamaForCausalLM.__init__ only takes a\n# positional argument. We inherit PipelineModel first, but call LlamaForCausalLM init first,\n# and make sure PipelineModel doesn't have a super().__init__() call.\nclass LlamaForCausalLMPipe(PipelineModel, transformers.LlamaForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.LlamaConfig.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.LlamaForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        embedding_relative_size = 4\n        embedding_on_cpu = not self.train_config['full_fine_tune']\n        result = [LayerSpec(InputLayer, self, _estimated_size=0 if embedding_on_cpu else embedding_relative_size)]\n        for block in self.model.layers:\n            result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n                _estimated_size=embedding_relative_size,\n            )\n        )\n        return result\n\n\nclass Qwen2ForCausalLMPipe(PipelineModel, transformers.Qwen2ForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.Qwen2Config.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.Qwen2ForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        result = [LayerSpec(InputLayer, self)]\n        for block in self.model.layers:\n            result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n            )\n        )\n        return result\n\n\nclass CohereForCausalLMPipe(PipelineModel, transformers.CohereForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.CohereConfig.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.CohereForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        # the embedding table for this model is huge; load balance it better with some heuristics\n        embedding_relative_size = 4\n        embedding_on_cpu = not self.train_config['full_fine_tune']\n        result = [LayerSpec(InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)]\n        for block in self.model.layers:\n            result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                logit_scale=self.logit_scale,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n                _estimated_size=embedding_relative_size,\n            )\n        )\n        return result\n\n\nclass Phi3ForCausalLMPipe(PipelineModel, transformers.Phi3ForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.Phi3Config.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.Phi3ForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        result = [LayerSpec(InputLayer, self)]\n        for block in self.model.layers:\n            # Phi3DecoderLayerPipe inherits LlamaDecoderLayerPipe.__init__(pipeline_model, loader_util, orig),\n            # so the pipeline model must be passed first, exactly as for the other decoder layers.\n            result.append(LayerSpec(Phi3DecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                # Handle tied input/output embeddings the same way as the other model classes.\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n            )\n        )\n        return result\n\n\nclass Gemma2ForCausalLMPipe(PipelineModel, transformers.Gemma2ForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.Gemma2Config.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.Gemma2ForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        # the embedding table for this model is huge; load balance it better with some heuristics\n        # this value optimized for LoRA, pipeline_stages=2\n        embedding_relative_size = 8\n        embedding_on_cpu = not self.train_config['full_fine_tune']\n        result = [LayerSpec(InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)]\n        for block in self.model.layers:\n            result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n                logit_softcapping=self.config.final_logit_softcapping,\n                _estimated_size=embedding_relative_size,\n            )\n        )\n        return result\n\n\nclass MistralForCausalLMPipe(PipelineModel, transformers.MistralForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.MistralConfig.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.MistralForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        result = [LayerSpec(InputLayer, self)]\n        for block in self.model.layers:\n            result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                # Handle tied input/output embeddings the same way as the other model classes.\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n            )\n        )\n        return result\n\n\nclass MixtralForCausalLMPipe(PipelineModel, transformers.MixtralForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.MixtralConfig.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.MixtralForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n        self.load_balancing_loss_coef = config.get('load_balancing_loss_coef', None)\n        self.num_experts_to_offload = self.num_experts\n        if 'offload_mlp_to_cpu' in config and isinstance(config['offload_mlp_to_cpu'], int):\n            self.num_experts_to_offload = config['offload_mlp_to_cpu']\n\n    def to_layer_specs(self):\n        result = [LayerSpec(InputLayer, self)]\n        for block in self.model.layers:\n            result.append(LayerSpec(MixtralDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                MixtralOutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                load_balancing_loss_coef=self.load_balancing_loss_coef,\n                num_experts=self.num_experts,\n                num_experts_per_tok=self.num_experts_per_tok,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n            )\n        )\n        return result\n\n\nclass Gemma3ForCausalLMPipe(PipelineModel, transformers.Gemma3ForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.Gemma3TextConfig.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.Gemma3ForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        # the embedding table for this model is huge; load balance it better with some heuristics\n        # this value optimized for LoRA, pipeline_stages=2\n        embedding_relative_size = 8\n        embedding_on_cpu = not self.train_config['full_fine_tune']\n        result = [LayerSpec(Gemma3InputLayer, self, _estimated_size=1 if embedding_on_cpu else embedding_relative_size)]\n        for block in self.model.layers:\n            result.append(LayerSpec(Gemma3DecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(Gemma3RMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n                logit_softcapping=self.config.final_logit_softcapping,\n                _estimated_size=embedding_relative_size,\n            )\n        )\n        return result\n\n\nclass Cohere2ForCausalLMPipe(PipelineModel, transformers.Cohere2ForCausalLM):\n    def __init__(self, config, quantization_config):\n        model_config = transformers.Cohere2Config.from_pretrained(config['model'])\n        model_config._attn_implementation = config.get('attn_implementation', DEFAULT_ATTN_IMPLEMENTATION)\n        torch.set_default_dtype(DTYPE_MAP[config.get('model_weight_dtype', 'bfloat16')])\n        with accelerate.init_empty_weights():\n            transformers.Cohere2ForCausalLM.__init__(self, model_config)\n            PipelineModel.__init__(self, config, quantization_config, model_config)\n        torch.set_default_dtype(torch.float32)\n\n    def to_layer_specs(self):\n        # the embedding table for this model is huge; load balance it better with some heuristics\n        embedding_relative_size = 8\n        embedding_on_cpu = not self.train_config['full_fine_tune']\n        result = [LayerSpec(InputLayer, self, _estimated_size=2 if embedding_on_cpu else embedding_relative_size)]\n        for block in self.model.layers:\n            result.append(LayerSpec(LlamaDecoderLayerPipe, self, self.loader_util, block))\n        result.append(LayerSpec(LlamaRMSNormPipe, self.loader_util, self.model.norm, _estimated_size=0))\n        result.append(\n            LayerSpec(\n                OutputLayer,\n                self,\n                self.loader_util,\n                self.lm_head,\n                logit_scale=self.logit_scale,\n                loss_type=self.loss_type,\n                focal_loss_gamma=self.focal_loss_gamma,\n                tie_weights='model.embed_tokens.weight' if self.config.tie_word_embeddings else None,\n                _estimated_size=embedding_relative_size,\n            )\n        )\n        return result\n"
  },
  {
    "path": "models/pipeline_model.py",
    "content": "import os\nfrom collections import defaultdict\nfrom inspect import signature\nimport re\n\nimport accelerate\nimport bitsandbytes as bnb\nimport transformers\nfrom deepspeed.accelerator import get_accelerator\nfrom hqq.core import quantize as hqq_quantize\nfrom torch import nn\nfrom transformers.integrations import get_keys_to_not_convert\nfrom accelerate.utils import set_module_tensor_to_device\n\nimport utils.hqq_utils as hqq_utils\nfrom utils.utils import is_main_process\n\n\nLANGUAGE_MODEL_WEIGHT_PREFIX_REGEX = r'^language_model\\.'\n\n\nclass PipelineModel(nn.Module):\n    def __init__(self, config, quantization_config, model_config):\n        if config['full_fine_tune'] and model_config.tie_word_embeddings:\n            raise NotImplementedError('FFT is not supported for models with tied embeddings')\n        self.train_config = config\n        self.model_config = model_config\n        self.modules_to_not_quantize = get_keys_to_not_convert(self)\n        self.loader_util = LoaderUtil(config['model'], quantization_config, self.modules_to_not_quantize)\n        self.loss_type = config.get('loss_type', 'cross_entropy_loss').lower()\n        if rl_config := config.get('rl', None):\n            self.loss_type = rl_config['method']\n        self.focal_loss_gamma = config.get('focal_loss_gamma', 0)\n        if self.focal_loss_gamma > 0 and is_main_process():\n            print(f\"Optimizing using '{self.loss_type}' with gamma={self.focal_loss_gamma}\")\n        self.dpo_reference_mode = False\n        self.sampling_mode = False\n\n        for name, p in self.named_parameters():\n            p.original_name = name\n\n    # need to override this method\n    def to_layer_specs(self):\n        raise NotImplementedError()\n\n    def set_dpo_reference_mode(self, dpo_reference_mode):\n        self.dpo_reference_mode = dpo_reference_mode\n\n    def set_sampling_mode(self, sampling_mode):\n        self.sampling_mode = sampling_mode\n        # Reset cache 
when sampling mode is modified. This ensures it's initialized and also clears memory at the end.\n        self.cache_dict = defaultdict(transformers.DynamicCache)\n        # We could try to use static cache at some point. During early testing with relatively short sequence lengths,\n        # it was the same sampling speed as DynamicCache. Note: will need to pass cache_position in transformer layer\n        # if using StaticCache.\n        # def make_static_cache():\n        #     return transformers.StaticCache(\n        #         self.model_config,\n        #         max_batch_size=1,\n        #         max_cache_len=1024,\n        #         device='cuda',\n        #         dtype=self.dtype,\n        #     )\n        # self.cache_dict = defaultdict(make_static_cache)\n\n    def set_cache(self, micro_batch_id):\n        self.cache = self.cache_dict[micro_batch_id]\n\n\ndef _partial_module_name_match(full_name, list_to_match):\n    return any(key in full_name for key in list_to_match)\n\n\ndef _replace_with_quantized_linear(parent_modules_map, name, full_name, quantization_config):\n    if isinstance(quantization_config, transformers.BitsAndBytesConfig):\n        _replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config)\n    elif isinstance(quantization_config, hqq_utils.CustomHQQConfig):\n        _replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config)\n    else:\n        raise NotImplementedError(f'Quantization config not implemented: {quantization_config}')\n\n\ndef _replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config):\n    \"\"\"Replace a Linear layer with a BNB quantized version.\"\"\"\n    if quantization_config.llm_int8_skip_modules is not None and _partial_module_name_match(\n        full_name, quantization_config.llm_int8_skip_modules\n    ):\n        return\n    module = parent_modules_map[name]\n    with accelerate.init_empty_weights():\n        if isinstance(module, 
nn.Conv1d):\n            in_features, out_features = module.weight.shape\n        else:\n            in_features = module.in_features\n            out_features = module.out_features\n\n        if quantization_config.quantization_method() == 'llm_int8':\n            parent_modules_map[name] = bnb.nn.Linear8bitLt(\n                in_features,\n                out_features,\n                module.bias is not None,\n                has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,\n                threshold=quantization_config.llm_int8_threshold,\n            )\n        else:\n            extra_kwargs = (\n                {'quant_storage': quantization_config.bnb_4bit_quant_storage}\n                if 'quant_storage' in list(signature(bnb.nn.Linear4bit).parameters)\n                else {}\n            )\n            parent_modules_map[name] = bnb.nn.Linear4bit(\n                in_features,\n                out_features,\n                module.bias is not None,\n                quantization_config.bnb_4bit_compute_dtype,\n                compress_statistics=quantization_config.bnb_4bit_use_double_quant,\n                quant_type=quantization_config.bnb_4bit_quant_type,\n                **extra_kwargs,\n            )\n        # Store the module class in case we need to transpose the weight later\n        parent_modules_map[name].source_cls = type(module)\n        # Force requires grad to False to avoid unexpected errors\n        parent_modules_map[name].requires_grad_(False)\n\n\ndef _replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config):\n    \"\"\"Replace a Linear layer with a HQQ quantized version.\"\"\"\n    if _partial_module_name_match(full_name, quantization_config.skip_modules):\n        return\n    module = parent_modules_map[name]\n    quant_config_dict = quantization_config.get_dict(full_name)\n    hqq_linear = hqq_quantize.HQQLinear(\n        module,\n        quant_config=quant_config_dict,\n        
compute_dtype=quantization_config.compute_dtype,\n        device=module.weight.device,\n        initialize=True,\n        del_orig=True,\n    )\n    # Quantization itself uses a decent amount of VRAM. Temporarily move each quantized parameter to the CPU as we\n    # finish, so the quant process doesn't OOM. Deepspeed will move everything to the correct device later.\n    hqq_linear.W_q.data = hqq_linear.W_q.data.to('cpu')\n    # Store the module class in case we need to transpose the weight later\n    hqq_linear.source_cls = type(module)\n    # Force requires grad to False to avoid unexpected errors\n    hqq_linear.requires_grad_(False)\n    parent_modules_map[name] = hqq_linear\n\n\n# modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py\ndef _recursively_replace_with_quantized_linear(\n    model,\n    modules_to_not_convert=None,\n    current_key_name=None,\n    quantization_config=None,\n):\n    \"\"\"\n    Returns the converted model and a boolean that indicates if the conversion has been successful or not.\n    \"\"\"\n    for name, module in model.named_children():\n        if current_key_name is None:\n            current_key_name = []\n        current_key_name.append(name)\n\n        if (isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d)) and name not in modules_to_not_convert:\n            # Check if the current key is not in the `modules_to_not_convert`\n            current_key_name_str = '.'.join(current_key_name)\n            if not any(\n                (key + '.' 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert\n            ):\n                _replace_with_quantized_linear(model._modules, name, current_key_name_str, quantization_config)\n\n                # copy over the original_name attribute we added earlier (needed for loading weights)\n                for orig_name, orig_p in module.named_parameters():\n                    if hasattr(orig_p, 'original_name'):\n                        for new_name, new_p in model._modules[name].named_parameters():\n                            if new_name == orig_name:\n                                new_p.original_name = orig_p.original_name\n\n        if len(list(module.children())) > 0:\n            _recursively_replace_with_quantized_linear(\n                module,\n                modules_to_not_convert,\n                current_key_name,\n                quantization_config,\n            )\n        # Remove the last key for recursion\n        current_key_name.pop(-1)\n\n\nclass LoaderUtil:\n    def __init__(self, model_path, quantization_config, modules_to_not_quantize):\n        self.model_path = model_path\n        self.quantization_config = quantization_config\n        self.modules_to_not_quantize = modules_to_not_quantize\n        self.local_rank = int(os.environ.get('LOCAL_RANK', None))\n        assert self.local_rank is not None\n        self.device = get_accelerator().device_name(self.local_rank)\n\n        index_file = os.path.join(model_path, transformers.utils.SAFE_WEIGHTS_INDEX_NAME)\n        if os.path.exists(index_file):\n            checkpoint_files, checkpoint_metadata = transformers.utils.hub.get_checkpoint_shard_files(\n                model_path, index_file, local_files_only=True\n            )\n            self.checkpoint_metadata = checkpoint_metadata\n        else:\n            self.checkpoint_metadata = None\n        self.loaded_state_dict = None\n\n    def get_partial_state_dict(self, leaf_file):\n        if 
self.loaded_state_dict is None or leaf_file != self.loaded_state_dict[0]:\n            print(f'loading checkpoint file {leaf_file}')\n            state_dict = transformers.modeling_utils.load_state_dict(os.path.join(self.model_path, leaf_file))\n            state_dict = {re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in state_dict.items()}\n            self.loaded_state_dict = (leaf_file, state_dict)\n        return self.loaded_state_dict[1]\n\n    def maybe_quantize(self, module):\n        if self.quantization_config is None:\n            return\n        modules_to_not_convert = self.modules_to_not_quantize\n        if not isinstance(modules_to_not_convert, list):\n            modules_to_not_convert = [modules_to_not_convert]\n        _recursively_replace_with_quantized_linear(\n            module, modules_to_not_convert=modules_to_not_convert, quantization_config=self.quantization_config\n        )\n        # Make sure to set this or PEFT (and probably other things) will break in strange ways.\n        # We only need this because we do the loading and quanting ourselves.\n        self.is_loaded_in_4bit = True\n\n    def load_state_dict_into_module(self, module):\n        print(f'load params into module {type(module)}')\n        if isinstance(self.quantization_config, transformers.BitsAndBytesConfig):\n            # bnb needs to replace with quantized linear before weights are loaded\n            self.maybe_quantize(module)\n        param_renaming_map = {p.original_name: new_name for new_name, p in module.named_parameters()}\n        expected_keys = [p.original_name for p in module.parameters()]\n        # If we have any extra attributes on the parameter, loading with BNB 4bit params breaks, so delete them.\n        for p in module.parameters():\n            del p.original_name\n\n        if self.checkpoint_metadata is not None:\n            weight_map = self.checkpoint_metadata['weight_map']\n            weight_map = 
{re.sub(LANGUAGE_MODEL_WEIGHT_PREFIX_REGEX, '', k): v for k, v in weight_map.items()}\n            needed_checkpoint_files = {weight_map[key.replace('orig.', '')] for key in expected_keys}\n        else:\n            needed_checkpoint_files = ['model.safetensors']\n\n        for checkpoint_file in needed_checkpoint_files:\n            state_dict = self.get_partial_state_dict(checkpoint_file)\n            renamed_state_dict = {param_renaming_map[k]: v for k, v in state_dict.items() if k in param_renaming_map}\n            for name, param in module.named_parameters():\n                if name in renamed_state_dict:\n                    set_module_tensor_to_device(module, name, device='cpu', value=renamed_state_dict[name])\n\n        module.to(self.device)\n        if not isinstance(self.quantization_config, transformers.BitsAndBytesConfig):\n            self.maybe_quantize(module)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\n# Only used by `hf-doc-builder´.\nline-length = 119\ntarget-version = ['py38']\n\n[tool.ruff]\ntarget-version = \"py38\"\nline-length = 119\nextend-exclude = [\"axolotl\"]\n\n[tool.ruff.format]\nquote-style = \"single\"\n\n[tool.ruff.lint]\nextend-select = [\n    \"C\", # Complexity\n    \"E\", # PEP8 errors\n    \"F\", # PEP8 formatting\n    \"I\", # Import sorting\n    \"UP\", # Pyupgrade upgrades\n    \"W\", # PEP8 warnings\n    # \"PT009\", # Pytest assertions\n]\nignore = [\n    \"C901\", # Function too complex\n    \"E501\", # Line length (handled by ruff-format)\n    \"E741\", # Allow the letter 'l'\n    \"UP007\", # X | Y style Unions\n    # \"PT009\", # self.assertEqual etc\n]\n\n[tool.ruff.lint.isort]\nlines-after-imports = 2\nknown-first-party = [\"peft\"]\n\n[tool.pytest]\ndoctest_optionflags = [\n    \"NORMALIZE_WHITESPACE\",\n    \"ELLIPSIS\",\n    \"NUMBER\",\n]\n\n[tool.pytest.ini_options]\naddopts = \"--cov=src/peft --cov-report=term-missing --durations=10\"\nmarkers = [\n    \"single_gpu_tests: tests that run on a single GPU\",\n    \"multi_gpu_tests: tests that run on multiple GPUs\",\n    \"regression: whether to run regression suite test\",\n    \"bitsandbytes: select bitsandbytes integration tests\"\n]\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntorchvision\ntorchaudio\naccelerate\nbitsandbytes\ndatasets\ndeepspeed\npackaging\npeft\nsafetensors\nscipy\nsentencepiece\ntensorboard\ntoml\njsonlines\ntransformers\nflash-attn\npyyaml\ntqdm\ntriton\nhqq\ntorch-optimi\nruff\n\n# minimum Axolotl dependencies for what we use\naddict\noptimum\npynvml\nevaluate\nwandb\nnumba\ncolorama\ntrl\nmlflow\ntermcolor\nfastcore\n"
  },
  {
    "path": "tools/convert_dpo_dataset_to_chat_format.py",
    "content": "# Convert a DPO dataset with prompt, chosen, rejected fields into chat format.\n# Usage: python convert_dpo_dataset_to_chat_format.py hf_username/some_dataset path/to/output/directory\nimport os\nimport sys\nfrom pathlib import Path\n\nimport datasets\n\n\ndataset_path, converted_path = sys.argv[1:]\n\ndataset = datasets.load_dataset(dataset_path)\n\n\ndef convert(x):\n    prompt = x['prompt']\n    chosen = [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': x['chosen']}]\n    rejected = [{'role': 'user', 'content': prompt}, {'role': 'assistant', 'content': x['rejected']}]\n    return {'chosen': chosen, 'rejected': rejected}\n\n\nnum_proc = min(64, os.cpu_count())\ndataset = dataset.map(convert, num_proc=num_proc)\nfor name, split in dataset.items():\n    filepath = Path(converted_path) / f'{name}.json'\n    split.to_json(filepath)\n"
  },
  {
    "path": "tools/convert_ds_checkpoint_to_lora.py",
    "content": "# Very hacky script to convert pipeline parallel Deepspeed checkpoints into a saved lora model.\n# I originally wrote this because I screwed up the lora model saving initially, and needed a\n# way to turn the training checkpoints into saved lora models to test them.\n\nimport os.path\nimport re\nfrom glob import glob\n\nimport torch\n\n\ndef convert_ds_checkpoint_to_lora(ds_checkpoint_dir, lora_output_dir):\n    layer_checkpoint_files = glob(os.path.join(ds_checkpoint_dir, 'layer_*-model_states.pt'))\n    combined_state_dict = {}\n    for path in layer_checkpoint_files:\n        match = re.fullmatch('layer_(.+)-model_states.pt', os.path.basename(path))\n        layer_idx = int(match.group(1)) - 2\n        state_dict = torch.load(path)\n        for name, weight in state_dict.items():\n            converted_name = name.replace('orig', f'base_model.model.model.layers.{layer_idx}').replace('.default', '')\n            combined_state_dict[converted_name] = weight\n    os.makedirs(lora_output_dir, exist_ok=True)\n    torch.save(combined_state_dict, os.path.join(lora_output_dir, 'adapter_model.bin'))\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input')\n    parser.add_argument('--output')\n    args = parser.parse_args()\n\n    convert_ds_checkpoint_to_lora(args.input, args.output)\n"
  },
  {
    "path": "tools/merge_lora.py",
    "content": "# Usage: python merge_lora.py input_path lora_path output_path\n# Output path is created if it doesn't exist\n\nimport argparse\nimport os\nimport re\nimport shutil\nfrom pathlib import Path\n\nimport safetensors\nimport torch\nfrom tqdm import tqdm\n\nimport peft\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('input_path', type=str, help='The path to the input directory.')\nparser.add_argument('lora_path', type=str, help='The path to the LoRA directory.')\nparser.add_argument('output_path', type=str, help='The path to the output directory.')\nparser.add_argument('--no-gpu', action='store_true', help='Use CPU for merging.')\nargs = parser.parse_args()\n\ninput_path, lora_path, output_path = Path(args.input_path), Path(args.lora_path), Path(args.output_path)\nos.makedirs(output_path, exist_ok=True)\n\nlora_config = peft.LoraConfig.from_json_file(lora_path / 'adapter_config.json')\nscale = lora_config['lora_alpha'] / lora_config['r']\n\ndevice = 'cpu' if args.no_gpu else 'cuda'\n\nprint('Loading LoRA model...')\n\n# Check if we have adapter_model.bin or adapter_model.safetensors\nif (lora_path / 'adapter_model.safetensors').exists():\n    lora_state = safetensors.torch.load_file(lora_path / 'adapter_model.safetensors')\n    if not args.no_gpu:\n        # Move mapped entries to cuda\n        for key, value in tqdm(lora_state.items()):\n            lora_state[key] = value.to('cuda')\nelse:\n    lora_state = torch.load(lora_path / 'adapter_model.bin', map_location=device)\n\n\ndef find_lora_weights(key):\n    lora_A = None\n    lora_B = None\n    for lora_key, lora_weight in lora_state.items():\n        if key.strip('.weight') in lora_key:\n            if 'lora_A' in lora_key:\n                lora_A = lora_weight\n            elif 'lora_B' in lora_key:\n                lora_B = lora_weight\n            else:\n                raise RuntimeError()\n    assert not ((lora_A is None) ^ (lora_B is None))\n    return lora_A, lora_B\n\n\nshards = 
[]\nfor shard in input_path.glob('model*.safetensors'):\n    shards.append(shard)\n\nprint('Copying unmergable files to output')\nfor filepath in input_path.glob('*'):\n    if filepath in shards:\n        continue\n    filepath = Path(filepath)\n    if filepath.is_dir():\n        continue\n    if filepath.suffix == '.gguf':\n        # Skip unrelated stray quantizations\n        continue\n    if filepath.suffix == '.safetensors':\n        # Consolidated, possibly\n        continue\n    print(f'copying {filepath.name} to output')\n    shutil.copy(filepath, output_path)\n\nprint('Merging and copying state_dict to output')\nfound = 0\nfor shard in (pbar := tqdm(shards)):\n    tensors = {}\n    with safetensors.safe_open(shard, framework='pt', device=device) as f:\n        metadata = f.metadata()\n        for key in f.keys():\n            lora_key = re.sub(r'^language_model\\.', '', key)\n            tensor = f.get_tensor(key)\n            lora_A, lora_B = find_lora_weights(lora_key)\n            if lora_A is not None:\n                found += 1\n                pbar.set_description(f'found lora weights for {key}: {lora_A.size()}, {lora_B.size()}')\n                old_type = tensor.dtype\n                tensor = tensor.to(torch.float32)\n                tensor += scale * lora_B.to(torch.float32) @ lora_A.to(torch.float32)\n                tensor = tensor.to(old_type)\n            tensors[key] = tensor\n        safetensors.torch.save_file(tensors, output_path / shard.name, metadata=metadata)\nprint(f\"Applied LoRA to {found} tensors.\")\n"
  },
  {
    "path": "tools/test_sampling.py",
    "content": "# deepspeed --num_gpus=1 --module tools.test_sampling --config ~/code/qlora-pipe-configs/config_8b_dpo.toml\n\nimport argparse\nimport json\nimport os.path\n\nimport bitsandbytes\nimport deepspeed\nimport toml\nimport transformers\n\nimport utils.engine as engine\nfrom train import load_pipeline_model_with_lora\nfrom utils.utils import DTYPE_MAP, is_main_process\n\n\nPROMPT_FORMAT = \"\"\"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\"\"\"\n\nPROMPTS = [\n    'Where is Popeye Village located?',\n    'What is the name of Sweden in Swedish?',\n]\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--config', help='Path to TOML configuration file.')\nparser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')\nparser = deepspeed.add_config_arguments(parser)\nargs = parser.parse_args()\n\n\nif __name__ == '__main__':\n    with open(args.config) as f:\n        config = toml.load(f)\n    config['full_fine_tune'] = True\n\n    if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:\n        # engine.initialize() will load deepspeed config from args\n        ds_config = None\n    else:\n        # The necessary ds_config fields are taken from the TOML config file.\n        ds_config = {\n            'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1),\n            'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1),\n            'gradient_clipping': config.get('gradient_clipping', 1.0),\n            'steps_per_print': config.get('steps_per_print', 1),\n        }\n\n    deepspeed.init_distributed()\n\n    with open(os.path.join(config['model'], 'config.json')) as f:\n        model_config = json.load(f)\n        model_type = model_config.get('model_type', 'llama')\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\n        config['model'],\n        
local_files_only=True,\n        model_max_length=int(1e30),\n        padding_side='left',\n    )\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n\n    # Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time.\n    bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda\n\n    def bnb_cuda_hijack(self, device):\n        if getattr(self, 'already_quantized', False):\n            self.data = self.data.to(device)\n            self.quant_state.to(device)\n            return self\n        self.already_quantized = True\n        return bnb_cuda_old(self, device)\n\n    bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack\n\n    pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type)\n\n    kwargs = {}\n    if sampling_settings := config.get('sampling', None):\n        for k, v in sampling_settings.items():\n            kwargs['sampling_' + k] = v\n    model_engine, _ = engine.initialize(\n        args=args,\n        model=pipeline_model,\n        lora_model=lora_model,\n        config=ds_config,\n        tokenizer=tokenizer,\n        **kwargs,\n    )\n    weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))]\n    model_engine.communication_data_type = weight_dtype\n\n    prompts = [PROMPT_FORMAT.format(prompt) for prompt in PROMPTS]\n    #prompts = [[PROMPT_FORMAT.format(prompt) for prompt in PROMPTS]]\n    outputs = model_engine.sample_batch(prompts, max_new_tokens=500)\n    if is_main_process():\n        for text in outputs:\n            print(text)\n            print('-' * 80)\n"
  },
  {
    "path": "train.py",
    "content": "import argparse\nimport glob\nimport itertools\nimport json\nimport os\nimport shutil\nimport time\nfrom contextlib import contextmanager\nfrom datetime import datetime, timezone\n\nimport bitsandbytes\nimport deepspeed\nimport toml\nimport torch\nimport transformers\nfrom deepspeed.runtime.pipe.module import LayerSpec\nfrom hqq.core import quantize as hqq_quantize\nfrom torch.utils.tensorboard import SummaryWriter\n\nimport utils.dataloader as dataloader\nimport utils.engine as engine\nimport utils.hqq_utils as hqq_utils\nimport models.models as models\nimport utils.unsloth_utils as unsloth_utils\nfrom utils.dataset_utils import load_datasets\nfrom peft import LoraConfig, get_peft_model\nfrom peft.optimizers import create_loraplus_optimizer\nfrom utils.saver import Saver\nfrom utils.utils import DTYPE_MAP, is_main_process\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--config', help='Path to TOML configuration file.')\nparser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')\nparser.add_argument('--debug_dataset', type=int, help='print out this many training examples and then quit')\nparser.add_argument(\n    '--resume_from_checkpoint',\n    action='store_true',\n    default=None,\n    help='resume training from the most recent checkpoint',\n)\nparser.add_argument('--no_quantiles', action='store_true', help='suppress output of quantile metrics')\nparser = deepspeed.add_config_arguments(parser)\nargs = parser.parse_args()\n\n\ndef print_model_info(model):\n    if not is_main_process():\n        return\n    print(model)\n    for name, module in model.named_modules():\n        print(f'{type(module)}: {name}')\n        for pname, p in module.named_parameters(recurse=False):\n            print(pname)\n            print(p.dtype)\n            print(p.device)\n            print(p.requires_grad)\n            print()\n\n\ndef set_config_defaults(config):\n    config['full_fine_tune'] = 
config.get('full_fine_tune', False)\n    config['load_in_4bit'] = config.get('load_in_4bit', False)\n\n\ndef get_most_recent_run_dir(output_dir):\n    return sorted(glob.glob(os.path.join(output_dir, '*')))[-1]\n\n\ndef write_metrics(tb_writer, prefix, metrics, step):\n    loss = metrics[0].mean().item()\n    tb_writer.add_scalar(f'{prefix}/loss', loss, step)\n\n    if len(metrics) > 1:\n        losses = metrics[1].view(-1)\n        positive_losses = losses > 0\n        tb_writer.add_histogram(f'{prefix}/log_loss_hist', torch.log(losses[positive_losses]), step)\n        if not args.no_quantiles:\n            sorted_losses, sorted_losses_idx = torch.sort(losses)\n            quantiles = torch.tensor(\n                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.999], dtype=torch.float32\n            ).to(losses.device)\n            quantiles_idx = [int(len(losses) * quantile) for quantile in quantiles]\n            loss_quantiles = [sorted_losses[i] for i in quantiles_idx]\n            for quantile, value in zip(quantiles, loss_quantiles):\n                tb_writer.add_scalar(f'{prefix}/loss_quantile_{quantile:.3f}', value, step)\n\n    if len(metrics) > 2:\n        hidden_norm_avg = metrics[2].mean().item()\n        tb_writer.add_scalar(f'{prefix}/hidden_norm_avg', hidden_norm_avg, step)\n        hidden_state_norms = metrics[2].view(-1)\n        tb_writer.add_histogram(f'{prefix}/hidden_norm_hist', hidden_state_norms, step)\n\n    if len(metrics) > 3:\n        entropy = metrics[3].view(-1)\n        tb_writer.add_scalar(f'{prefix}/entropy', entropy.mean().item(), step)\n        if not args.no_quantiles:\n            assert entropy.size() == losses.size(), (entropy.size(), losses.size())\n            sorted_entropy = entropy[sorted_losses_idx]\n            entropy_quantiles = []\n            for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):\n                entropy_quantiles.append(sorted_entropy[i:j].mean())\n     
       for quantile, value in zip(quantiles, entropy_quantiles):\n                tb_writer.add_scalar(f'{prefix}/entropy_quantile_{quantile:.3f}', value, step)\n\n    if len(metrics) > 4:\n        normalised_entropy = metrics[4].view(-1)\n        tb_writer.add_scalar(f'{prefix}/normalised_entropy', normalised_entropy.mean().item(), step)\n        if not args.no_quantiles:\n            assert normalised_entropy.size() == losses.size()\n            sorted_normalised_entropy = normalised_entropy[sorted_losses_idx]\n            normalised_entropy_quantiles = []\n            for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):\n                normalised_entropy_quantiles.append(sorted_normalised_entropy[i:j].mean())\n            for quantile, value in zip(quantiles, normalised_entropy_quantiles):\n                tb_writer.add_scalar(f'{prefix}/normalised_entropy_quantile_{quantile:.3f}', value, step)\n\n    if len(metrics) > 5:\n        log_likelihood = metrics[5].mean()\n        tb_writer.add_scalar(f'{prefix}/log_likelihood', log_likelihood.item(), step)\n        likelihood = torch.exp(-log_likelihood).item()\n        tb_writer.add_scalar(f'{prefix}/likelihood', likelihood, step)\n        perplexity = torch.exp(log_likelihood).item()\n        tb_writer.add_scalar(f'{prefix}/perplexity', perplexity, step)\n\n    if len(metrics) > 6:\n        mcfaddens_pseudo_r2 = metrics[6].mean()\n        tb_writer.add_scalar(f'{prefix}/mcfaddens_pseudo_r2', mcfaddens_pseudo_r2.item(), step)\n\n    if len(metrics) > 7:\n        tb_writer.add_scalar(f'{prefix}/top1_accuracy', metrics[7].mean().item(), step)\n        tb_writer.add_scalar(f'{prefix}/top5_accuracy', metrics[8].mean().item(), step)\n        tb_writer.add_scalar(f'{prefix}/top20_accuracy', metrics[9].mean().item(), step)\n\n    if len(metrics) > 10:\n        tb_writer.add_scalar(f'{prefix}/load_balancing_loss', metrics[10].mean().item(), step)\n\n    if len(metrics) > 11:\n        
tb_writer.add_scalar(f'{prefix}/alternate_load_balancing_loss', metrics[11].mean().item(), step)\n\n    return loss\n\n\ndef evaluate_single(model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps):\n    orig_micro_batches = model_engine.micro_batches\n    model_engine.micro_batches = eval_gradient_accumulation_steps\n    iterator = iter(eval_dataloader)\n    all_metrics = None\n    while True:\n        metrics = model_engine.eval_batch(iterator)\n        eval_dataloader.sync_epoch()\n        if all_metrics is None:\n            all_metrics = [[] for _ in range(len(metrics))]\n        if eval_dataloader.epoch == 2:\n            break\n        for i, metric in enumerate(metrics):\n            all_metrics[i].append(metric)\n\n    eval_dataloader.reset()\n    model_engine.micro_batches = orig_micro_batches\n    eval_metrics = [torch.cat(metric_list) for metric_list in all_metrics]\n    loss = None\n    if is_main_process():\n        loss = write_metrics(tb_writer, f'eval/{name}', eval_metrics, step)\n    return loss\n\n\ndef evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps):\n    if is_main_process():\n        print('Running eval')\n    start = time.time()\n    loss = []\n    for name, eval_dataloader in eval_dataloaders.items():\n        loss_or_none = evaluate_single(\n            model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps\n        )\n        if loss_or_none is not None:\n            loss.append(loss_or_none)\n    duration = time.time() - start\n    if is_main_process():\n        tb_writer.add_scalar('eval/eval_time_sec', duration, step)\n    return sum(loss) / len(loss) if len(loss) > 0 else None\n\n\ndef apply_max_norm_regularization(model, config):\n    # modifed from https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py\n    A_keys = []\n    B_keys = []\n    norms = []\n    keys_scaled = 0\n    lora_scale = config['lora_alpha'] / 
config['lora_rank']\n\n    state_dict = model.state_dict()\n    for key in state_dict.keys():\n        if 'lora_A' in key:\n            A_keys.append(key)\n            B_keys.append(key.replace('lora_A', 'lora_B'))\n\n    for i in range(len(A_keys)):\n        A = state_dict[A_keys[i]]\n        B = state_dict[B_keys[i]]\n        W = B @ A\n        W *= lora_scale\n\n        if 'scale_weight_norms' in config:\n            max_norm = config['scale_weight_norms']\n            norm = W.norm().clamp(min=max_norm / 2)\n            desired = torch.clamp(norm, max=max_norm)\n            ratio = desired.cpu() / norm.cpu()\n            sqrt_ratio = ratio**0.5\n            if ratio != 1:\n                keys_scaled += 1\n                state_dict[A_keys[i]] *= sqrt_ratio\n                state_dict[B_keys[i]] *= sqrt_ratio\n        else:\n            ratio = 1.0\n        scalednorm = W.norm() * ratio\n        norms.append(scalednorm.item())\n\n    if len(norms) > 0:\n        norms = torch.tensor(norms, dtype=torch.float32)\n        if torch.any(torch.isnan(norms)):\n            raise RuntimeError('NaN detected in norms, probably some/all weights are NaN')\n        avg_norm = sum(norms) / len(norms)\n        max_norm = max(norms)\n    else:\n        avg_norm = 0\n        max_norm = 0\n    return keys_scaled, avg_norm, max_norm, norms\n\n\ndef parse_layers_to_transform(spec):\n    parts = spec.split(',')\n    result = []\n    for part in parts:\n        start, stop = part.split(':')\n        result.extend(range(int(start), int(stop) + 1))\n    return result\n\n\n@contextmanager\ndef one_at_a_time():\n    for i in range(int(os.environ['LOCAL_SIZE'])):\n        if i == int(os.environ['LOCAL_RANK']):\n            yield\n        deepspeed.comm.barrier()\n\n\ndef load_pipeline_model_with_lora(config, model_type):\n    full_fine_tune = config['full_fine_tune']\n\n    if config.get('quantization', None):\n        assert not full_fine_tune\n        no_quant_modules = ['lm_head']\n     
   if model_type == 'mixtral':\n            # the expert routing weights are tiny and probably important, don't quantize\n            no_quant_modules.append('gate')\n        if bnb_quant_config := config['quantization'].get('bnb', None):\n            if bnb_compute_dtype := bnb_quant_config.get('bnb_4bit_compute_dtype', None):\n                bnb_quant_config['bnb_4bit_compute_dtype'] = DTYPE_MAP[bnb_compute_dtype]\n            if 'bnb_4bit_quant_type' not in bnb_quant_config:\n                # Always want to default to nf4 if not specified.\n                bnb_quant_config['bnb_4bit_quant_type'] = 'nf4'\n            if llm_int8_skip_modules := bnb_quant_config.get('llm_int8_skip_modules', None):\n                no_quant_modules.extend(llm_int8_skip_modules)\n                no_quant_modules = list(set(no_quant_modules))\n            bnb_quant_config['llm_int8_skip_modules'] = no_quant_modules\n            quantization_config = transformers.BitsAndBytesConfig(**bnb_quant_config)\n        elif hqq_quant_config := config['quantization'].get('hqq', None):\n            quantization_config = hqq_utils.CustomHQQConfig(**hqq_quant_config)\n            # Use ATEN backend if possible, else PYTORCH. 
PYTORCH_COMPILE was only a tiny bit faster, and requires triton nightly.\n            hqq_quantize.HQQLinear.set_backend(\n                hqq_quantize.HQQBackend.ATEN if quantization_config.use_aten() else hqq_quantize.HQQBackend.PYTORCH\n            )\n        else:\n            raise NotImplementedError('Invalid quantization config')\n        if is_main_process():\n            print(f'Quantization config: {quantization_config}')\n    else:\n        quantization_config = None\n\n    if model_type == 'llama':\n        model = models.LlamaForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'mixtral':\n        model = models.MixtralForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'qwen2':\n        model = models.Qwen2ForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'cohere':\n        model = models.CohereForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'phi3':\n        model = models.Phi3ForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'gemma2':\n        model = models.Gemma2ForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'mistral' or model_type == 'mistral3':\n        model = models.MistralForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'gemma3':\n        model = models.Gemma3ForCausalLMPipe(config, quantization_config=quantization_config)\n    elif model_type == 'cohere2':\n        model = models.Cohere2ForCausalLMPipe(config, quantization_config=quantization_config)\n    else:\n        raise NotImplementedError(f'model_type {model_type} is not implemented')\n\n    # CAREFUL! 
The \"primary\" layers of the model have to have 'decoderlayer' in them for\n    # activation checkpointing to automatically work correctly.\n    layers = model.to_layer_specs()\n    checkpointable_layers = set()\n    for layer in layers:\n        if isinstance(layer, LayerSpec) and 'decoderlayer' in layer.typename.__name__.lower():\n            checkpointable_layers.add(layer.typename.__name__)\n    checkpointable_layers = list(checkpointable_layers)\n\n    partition_method = 'estimated_size'\n    if config['activation_checkpointing']:\n        # NOTE: must use a reentrant checkpointing function for MLP offloading to work.\n        if config['activation_checkpointing'] == 'unsloth':\n            checkpoint_func = unsloth_utils.unsloth_checkpoint\n        elif config['activation_checkpointing'] == 'cpu':\n            deepspeed.checkpointing.configure(None, checkpoint_in_cpu=True)\n            checkpoint_func = deepspeed.checkpointing.checkpoint\n        else:\n            checkpoint_func = deepspeed.checkpointing.checkpoint\n        pipeline_model = engine.CustomPipelineModule(\n            layers=layers,\n            num_stages=config['pipeline_stages'],\n            activation_checkpoint_interval=1,\n            checkpointable_layers=checkpointable_layers,\n            activation_checkpoint_func=checkpoint_func,\n            partition_method=partition_method,\n            use_column_major_topology=config.get('use_column_major_topology', False),\n            model=model,\n            dynamic_shape=True,\n        )\n    else:\n        pipeline_model = engine.CustomPipelineModule(\n            layers=layers,\n            num_stages=config['pipeline_stages'],\n            partition_method=partition_method,\n            use_column_major_topology=config.get('use_column_major_topology', False),\n        )\n\n    target_modules = config['target_modules'] if 'target_modules' in config else 'all-linear'\n    if full_fine_tune:\n        lora_model = None\n        
lora_config = None\n        for name, p in model.named_parameters():\n            p.original_name = name\n        if isinstance(target_modules, list):\n            for name, p in model.named_parameters():\n                if not any(target in name for target in config['target_modules']):\n                    p.requires_grad = False\n                    print(f'not training {name} because it is not present in target_modules')\n    else:\n        layers_to_transform = (\n            parse_layers_to_transform(config['layers_to_transform']) if 'layers_to_transform' in config else None\n        )\n        lora_config = LoraConfig(\n            r=config['lora_rank'],\n            lora_alpha=config['lora_alpha'],\n            target_modules=target_modules,\n            modules_to_save=config['modules_to_save'] if 'modules_to_save' in config else [],\n            lora_dropout=config['lora_dropout'] if 'lora_dropout' in config else 0,\n            layers_to_transform=layers_to_transform,\n            bias='none',\n            task_type='CAUSAL_LM',\n            use_dora=config.get('use_dora', False),\n        )\n\n        lora_model = get_peft_model(model, lora_config)\n        # If the underlying weights are floats, the lora weights have already been\n        # cast to the same dtype, so we need to change the dtype here.\n        for p in lora_model.parameters():\n            if p.requires_grad:\n                p.data = p.data.to(DTYPE_MAP[config.get('lora_weight_dtype', 'float32')])\n\n        lora_model.model.config.use_cache = False\n        for name, p in lora_model.named_parameters():\n            p.original_name = name\n\n    return pipeline_model, lora_model, lora_config\n\n\nif __name__ == '__main__':\n    # TODO: if resuming from checkpoint, probably should read all config files from checkpoint dir\n    # rather than assume they are unchanged on the command line\n    with open(args.config) as f:\n        config = toml.load(f)\n    set_config_defaults(config)\n\n  
  if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:\n        # engine.initialize() will load deepspeed config from args\n        ds_config = None\n    else:\n        # The necessary ds_config fields are taken from the TOML config file.\n        ds_config = {\n            'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1),\n            'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1),\n            'gradient_clipping': config.get('gradient_clipping', 1.0),\n            'steps_per_print': config.get('steps_per_print', 1),\n        }\n\n    resume_from_checkpoint = (\n        args.resume_from_checkpoint\n        if args.resume_from_checkpoint is not None\n        else config['resume_from_checkpoint']\n        if 'resume_from_checkpoint' in config\n        else False\n    )\n\n    deepspeed.init_distributed()\n\n    with open(os.path.join(config['model'], 'config.json')) as f:\n        model_config = json.load(f)\n        model_type = model_config.get('model_type', 'llama')\n\n    # Pad on left to support training techniques that involve sampling from the model.\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\n        config['model'], local_files_only=True, model_max_length=int(1e30), padding_side='left'\n    )\n    # TODO: do we want to do this with cohere models? 
By default the EOS token is <|END_OF_TURN_TOKEN|>\n    # if model_type == 'cohere':\n    #     tokenizer.eos_token = '<EOS_TOKEN>'\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n\n    train_data, eval_data_map = load_datasets(config, tokenizer)\n\n    if args.debug_dataset:\n        if is_main_process():\n            for i, item in enumerate(iter(train_data)):\n                print('input_ids:')\n                print(item['input_ids'])\n                print('decoded input_ids:')\n                print(tokenizer.decode(item['input_ids']))\n                print('attention_mask:')\n                print(item['attention_mask'])\n                print('labels:')\n                print(item['labels'])\n                if 'rejected_input_ids' in item:\n                    print('rejected_input_ids:')\n                    print(item['rejected_input_ids'])\n                    print('decoded rejected_input_ids:')\n                    print(tokenizer.decode(item['rejected_input_ids']))\n                    print('rejected_attention_mask:')\n                    print(item['rejected_attention_mask'])\n                    print('rejected_labels:')\n                    print(item['rejected_labels'])\n                print('-' * 80)\n                if i >= args.debug_dataset - 1:\n                    break\n        quit()\n\n    # for testing\n    # train_data = train_data.select(list(range(100)))\n\n    # if this is a new run, create a new dir for it\n    if not resume_from_checkpoint and is_main_process():\n        run_dir = os.path.join(config['output_dir'], datetime.now(timezone.utc).strftime('%Y%m%d_%H-%M-%S'))\n        os.makedirs(run_dir, exist_ok=True)\n        shutil.copy(args.config, run_dir)\n        if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:\n            shutil.copy(args.deepspeed_config, run_dir)\n    # wait for all processes then get the most recent dir (may have just been created)\n    
deepspeed.comm.barrier()\n    run_dir = get_most_recent_run_dir(config['output_dir'])\n\n    # Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time.\n    bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda\n\n    def bnb_cuda_hijack(self, device):\n        if getattr(self, 'already_quantized', False):\n            self.data = self.data.to(device)\n            self.quant_state.to(device)\n            return self\n        self.already_quantized = True\n        return bnb_cuda_old(self, device)\n\n    bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack\n\n    pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type)\n\n    parameters_to_train = [p for p in pipeline_model.parameters() if p.requires_grad]\n\n    optim_config = config['optimizer']\n\n    def get_optimizer(model_parameters):\n        lr = optim_config['lr']\n        optim_type = optim_config['type'].lower()\n        optimizer_kwargs = {\n            'params': model_parameters,\n            'lr': lr,\n            'betas': (optim_config.get('beta1', 0.9), optim_config.get('beta2', 0.99)),\n            'weight_decay': optim_config.get('weight_decay', 0.01),\n            'eps': optim_config.get('eps', 1e-6),\n        }\n        if optim_type == 'adamw':\n            optimizer_cls = deepspeed.ops.adam.FusedAdam\n        elif optim_type == 'adamw8bit':\n            optimizer_cls = bitsandbytes.optim.AdamW8bit\n        elif optim_type == 'adamw_kahan':\n            import optimi\n\n            optimizer_cls = optimi.AdamW\n            optimizer_kwargs['kahan_sum'] = optim_config.get('kahan_sum', True)\n        else:\n            raise NotImplementedError(optim_type)\n        if optim_config.get('use_loraplus', False):\n            loraplus_lr_ratio = optim_config.get('loraplus_lr_ratio', 16)\n            # TODO: handle params being thrown out here; why is it included in the first place?\n          
  # delete 'params' from optimizer_kwargs\n            del optimizer_kwargs['params']\n            return create_loraplus_optimizer(\n                model=pipeline_model,\n                optimizer_cls=optimizer_cls,\n                loraplus_lr_ratio=loraplus_lr_ratio,\n                **optimizer_kwargs,\n            )\n        return optimizer_cls(**optimizer_kwargs)\n\n    kwargs = {}\n    if sampling_settings := config.get('sampling', None):\n        for k, v in sampling_settings.items():\n            kwargs['sampling_' + k] = v\n    rejected_sampling = config.get('rejected_sampling', False)\n    rl_config=config.get('rl', None)\n    model_engine, optimizer = engine.initialize(\n        args=args,\n        model=pipeline_model,\n        model_parameters=parameters_to_train,\n        optimizer=get_optimizer,\n        lora_model=lora_model,\n        config=ds_config,\n        tokenizer=tokenizer,\n        rl_config=rl_config,\n        rejected_sampling=rejected_sampling,\n        rejected_sampling_max_new_tokens=config.get('rejected_sampling_max_new_tokens', 1e9),\n        **kwargs,\n    )\n\n    # TODO: I have recently realized that we are setting things to fp16/bf16, even though all the DS\n    # config was not in fp16 / bf16 mode. DS being in fp16/bf16 changes things in many places, e.g.\n    # it can give you a BF16_Optimizer wrapper that accumulates grads in fp32, the communication dtype\n    # is different, etc. I need to really look through all the implications of this. 
This change is so\n    # that we keep the normal optimizer, but the communication dtype is changed so that we don't\n    # unnecessarily cast grads to fp32.\n    weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))]\n    model_engine.communication_data_type = weight_dtype\n\n    # TODO: the main DeepSpeedEngine forces all parameters to the GPU, and also does things like\n    # broadcast all parameters from data parallel rank 0 to all other ranks. Thus, MLP offloading\n    # must come after engine.initialize(). If we want to avoid loading everything onto GPUs only\n    # to offload the MLPs, we have to rewrite a lot of code to work around things.\n    if config.get('offload_mlp_to_cpu', False):\n        assert config['activation_checkpointing']  # MLP offloading only works with activation checkpointing\n        for module in pipeline_model.modules():\n            if hasattr(module, 'move_mlp_to_cpu'):\n                module.move_mlp_to_cpu()\n        torch.cuda.empty_cache()\n\n    train_dataloader = dataloader.PipelineDataLoader(\n        train_data,\n        tokenizer,\n        model_engine.train_micro_batch_size_per_gpu(),\n        model_engine.gradient_accumulation_steps(),\n        model_engine.grid.get_data_parallel_world_size(),\n        model_engine.grid.get_data_parallel_rank(),\n        group_by_length=False if 'group_by_length' not in config else config['group_by_length'],\n        batch_size_tokens=None if 'batch_size_tokens' not in config else config['batch_size_tokens'],\n        return_dict=rejected_sampling,\n        rl=(rl_config is not None),\n    )\n    model_engine.set_dataloader(train_dataloader)\n    steps_per_epoch = len(train_dataloader) // model_engine.gradient_accumulation_steps()\n    model_engine.total_steps = steps_per_epoch * config['epochs']\n\n    if is_main_process():\n        # Warn if eval dataset is unusually large compared to the eval steps\n        eval_data_length = 
sum([len(eval_data) for eval_data in eval_data_map.values()])\n        train_data_length = len(train_data)\n        evals_per_epoch = steps_per_epoch / config['eval_steps']\n        relative_eval_time = evals_per_epoch * eval_data_length\n        # train step very roughly 3 times slower due to backprop + usually activation checkpointing is enabled\n        relative_train_time = train_data_length * 3\n        # Expect <=15% of our time spent evaluating vs training\n        fraction_evaling = relative_eval_time / (relative_eval_time + relative_train_time)\n        print()\n        print(\n            f'eval_data_length: {eval_data_length}, eval_steps: {config[\"eval_steps\"]}; evals per epoch: {evals_per_epoch}. '\n            f'We will be spending approximately {fraction_evaling * 100:.2f}% of our time evaluating.'\n        )\n        if fraction_evaling > 0.15:\n            print(\n                'WARNING: eval dataset is unusually large compared to eval_steps. We will spend a lot of time evaluating. Lowering eval_size and/or bumping eval_steps is recommended.'\n            )\n        print()\n\n    # handle Deepspeed optimizer wrapper (e.g. 
BF16_Optimizer)\n    optimizer = getattr(optimizer, 'optimizer', optimizer)\n\n    warmup_steps = config.get('warmup_steps', 0)\n    # Fractional values less than 1 are converted into \"fraction of epoch\" worth of steps\n    if 0 < warmup_steps < 1:\n        warmup_steps = int(warmup_steps * steps_per_epoch)\n\n    if 'lr_scheduler' not in config or config['lr_scheduler'] == 'constant' or config['lr_scheduler'] == 'none':\n        lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)\n    elif config['lr_scheduler'] == 'cosine':\n        total_steps = steps_per_epoch * config['epochs']\n        total_steps -= warmup_steps\n        lr_scheduler_kwargs = {\n            'optimizer': optimizer,\n            'T_max': total_steps,\n        }\n        if 'lr_min' in optim_config:\n            lr_scheduler_kwargs['eta_min'] = optim_config['lr_min']\n\n        # Normally, you would pass the lr_scheduler to deepspeed.initialize(). But we need the\n        # global batch_size in order to make the lr_scheduler.\n        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(**lr_scheduler_kwargs)\n    else:\n        raise NotImplementedError()\n\n    load_optimizer_states = config.get('load_optimizer_states', True)\n    # if resuming and not loading optimizer states, we can't use warmup or the LR never changes from the initial value (still don't know why)\n    if warmup_steps > 0 and load_optimizer_states:\n        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(\n            optimizer, start_factor=1 / warmup_steps, total_iters=warmup_steps\n        )\n        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(\n            optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]\n        )\n\n    model_engine.lr_scheduler = lr_scheduler\n\n    step = 1\n    if resume_from_checkpoint:\n        load_path, client_state = model_engine.load_checkpoint(\n            run_dir,\n            load_module_strict=False,\n            
load_lr_scheduler_states='force_constant_lr' not in config,\n            load_optimizer_states=load_optimizer_states,\n        )\n        deepspeed.comm.barrier()  # just so the print below doesn't get swamped\n        assert load_path is not None\n        train_dataloader.load_state_dict(client_state['custom_loader'])\n        step = client_state['step'] + 1\n        del client_state\n        # if we skip loading the optimizer states, we need to step the LR scheduler so we start at the right value\n        if not load_optimizer_states:\n            model_engine.lr_scheduler.step()\n        if is_main_process():\n            print(f'Resuming training from checkpoint. Resuming at epoch: {train_dataloader.epoch}, step: {step}')\n\n    if 'force_constant_lr' in config:\n        model_engine.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)\n        for pg in optimizer.param_groups:\n            pg['lr'] = config['force_constant_lr']\n\n    # this is a separate option, because if it's too high we might drop a significant fraction of the eval dataset\n    eval_gradient_accumulation_steps = (\n        config['eval_gradient_accumulation_steps'] if 'eval_gradient_accumulation_steps' in config else 1\n    )\n    # Eval dataset doesn't need to repeat; we just use this to track \"epoch\" so we know when we're done iterating over it.\n    eval_dataloaders = {\n        name: dataloader.PipelineDataLoader(\n            eval_data,\n            tokenizer,\n            model_engine.train_micro_batch_size_per_gpu(),\n            eval_gradient_accumulation_steps,\n            model_engine.grid.get_data_parallel_world_size(),\n            model_engine.grid.get_data_parallel_rank(),\n            shuffle=False,\n            group_by_length=False if 'group_by_length' not in config else config['group_by_length'],\n            batch_size_tokens=None if 'batch_size_tokens' not in config else config['batch_size_tokens'],\n            return_dict=rejected_sampling,\n   
         rl=(rl_config is not None),\n        )\n        for name, eval_data in eval_data_map.items()\n    }\n\n    tb_writer = SummaryWriter(log_dir=run_dir) if is_main_process() else None\n\n    epoch = train_dataloader.epoch\n\n    saver = Saver(model_engine, pipeline_model, train_dataloader, lora_config, run_dir, args, config)\n\n    epoch = train_dataloader.epoch\n\n    if config.get('eval_before_first_step', False) and not resume_from_checkpoint:\n        loss = evaluate(model_engine, eval_dataloaders, tb_writer, 0, eval_gradient_accumulation_steps)\n        saver.append_eval_results(loss, save_best=False)\n\n    while True:\n        metrics = model_engine.train_batch()\n        train_dataloader.sync_epoch()\n        if lora_config is not None:\n            keys_scaled, avg_norm, max_norm, norms = apply_max_norm_regularization(pipeline_model, config)\n\n        new_epoch = saver.process_epoch(epoch, step)\n        finished_epoch = True if new_epoch != epoch else False\n\n        if is_main_process() and step % config['logging_steps'] == 0:\n            write_metrics(tb_writer, 'train', metrics, step)\n            tb_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)\n            # TODO: gather the weight norms across all stages in the pipelined model, not just the first.\n            if lora_config is not None and len(norms) > 0:\n                tb_writer.add_scalar('train/weights_scaled', keys_scaled, step)\n                tb_writer.add_scalar('train/weight_norm_avg', avg_norm, step)\n                tb_writer.add_scalar('train/weight_norm_max', max_norm, step)\n                tb_writer.add_histogram('train/weight_norm_hist', norms, step)\n            tb_writer.add_scalar('train/epoch', step / steps_per_epoch, step)\n\n        if step % config['eval_steps'] == 0:\n            loss = evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps)\n            saver.append_eval_results(loss)\n\n        
saver.process_step(step)\n\n        if finished_epoch:\n            epoch = new_epoch\n            if epoch is None:\n                break\n\n        step += 1\n\n    if ((step - 1) % config['eval_steps'] != 0) and config.get('eval_after_last_step', False):\n        loss = evaluate(model_engine, eval_dataloaders, tb_writer, step - 1, eval_gradient_accumulation_steps)\n        saver.append_eval_results(loss)\n\n    if is_main_process():\n        print('TRAINING COMPLETE!')\n"
  },
  {
    "path": "utils/dataloader.py",
    "content": "import math\nimport os.path\nimport sys\n\n\nsys.path.insert(0, os.path.abspath('axolotl/src'))\n\nimport accelerate\nimport torch\nimport transformers\nfrom deepspeed import comm as dist\nfrom torch.utils.data import DataLoader\n\nfrom axolotl.utils.collators import DataCollatorForSeq2Seq\n\n\nfrom utils.utils import is_main_process\n\n\n# A100 wants padding to multiple of 64, other cards are efficient with smaller, so just do 64\nPAD_TO_MULTIPLE = 64\n\n\n# Splits an example (feature dict) along the batch dimension into a list of examples.\ndef split_batch(example, pieces):\n    input_ids = example['input_ids']\n    if is_main_process():\n        print(f'before GAS splitting, input_ids shape: {input_ids.shape}, total tokens: {input_ids.numel()}')\n    input_batch_size = input_ids.size(0)\n    split_size = input_batch_size // pieces\n    examples = [{} for _ in range(pieces)]\n    for key, tensor in example.items():\n        assert tensor.size(0) == input_batch_size\n        for i, tensor_slice in enumerate(torch.split(tensor, split_size)):\n            examples[i][key] = tensor_slice\n    return examples\n\n\n# Merge lists of examples a and b, such that for each contiguous piece in the result, the first half comes from\n# a and the second half from b. Used for DPO. 
The splitting must match how split_batch() does it.\ndef combine_piecewise(a, b, pieces):\n    assert len(a) == len(b)\n    split_size = len(a) // pieces\n    a_chunks = [a[i : i + split_size] for i in range(0, len(a), split_size)]\n    b_chunks = [b[i : i + split_size] for i in range(0, len(b), split_size)]\n    result = []\n    for a_chunk, b_chunk in zip(a_chunks, b_chunks):\n        result.extend(a_chunk)\n        result.extend(b_chunk)\n    return result\n\n\n# Flattens a list of examples with batch dimension into a list of examples with no batch dimension.\ndef flatten_examples(examples):\n    result = []\n    for example in examples:\n        batch_size = example['input_ids'].size(0)\n        new_examples = [{} for _ in range(batch_size)]\n        for key, tensor in example.items():\n            assert tensor.size(0) == batch_size\n            for i, tensor_slice in enumerate(tensor):\n                new_examples[i][key] = tensor_slice\n        result.extend(new_examples)\n    return result\n\n\ndef example_to_tuple(example):\n    return (example['input_ids'], example['attention_mask'], example['labels']), None\n\n\ndef shuffle_list(l, seed):\n    g = torch.Generator()\n    g.manual_seed(seed)\n    shuffle_idx = torch.randperm(len(l), generator=g).tolist()\n    new_l = [l[i] for i in shuffle_idx]\n    return new_l\n\n\ndef batch_size_tokens_after_padding(batch):\n    return max(math.ceil(pair[1] / PAD_TO_MULTIPLE) * PAD_TO_MULTIPLE for pair in batch) * len(batch)\n\n\n# A distributed batch sampler that supports grouping by length\nclass DistributedBatchSamper(torch.utils.data.Sampler):\n    def __init__(\n        self,\n        dataset,\n        batch_size,\n        num_replicas,\n        rank,\n        batch_size_multiplier=1,\n        shuffle=True,\n        group_by_length=False,\n        seed=0,\n        batch_size_tokens=None,\n    ):\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.batch_size_tokens = 
batch_size_tokens\n        self.batch_size_multiplier = batch_size_multiplier\n        self.num_replicas = num_replicas\n        self.rank = rank\n        # every global batch must be evenly divisible by this amount\n        self.chunk_size = self.num_replicas * self.batch_size_multiplier\n        self.shuffle = shuffle\n        self.group_by_length = group_by_length\n        self.seed = seed\n\n        # Make list of (index, size). Sort or shuffle as needed.\n        indices = list(enumerate(self.dataset['length']))\n        if self.group_by_length:\n            indices.sort(key=lambda t: t[1])\n        elif self.shuffle:\n            indices = shuffle_list(indices, self.seed)\n\n        # Group indices together into global batches.\n        global_batches = []\n        current_batch = []\n        for i in range(0, len(indices), self.chunk_size):\n            slice = indices[i : i + self.chunk_size]\n            if len(slice) < self.chunk_size:\n                # pad with random examples if slice is too small\n                padding_size = self.chunk_size - len(slice)\n                shuffled_indices = shuffle_list(indices, self.seed + 1)\n                if padding_size < len(shuffled_indices):\n                    slice += shuffled_indices[:padding_size]\n                else:\n                    slice += (shuffled_indices * math.ceil(padding_size / len(shuffled_indices)))[:padding_size]\n\n            if self.should_emit_current_batch(current_batch, slice):\n                global_batches.append(current_batch)\n                current_batch = []\n            current_batch.extend(slice)\n\n        # Emit anything remaining\n        if len(current_batch) > 0:\n            global_batches.append(current_batch)\n\n        if self.shuffle:\n            global_batches = shuffle_list(global_batches, self.seed + 2)\n\n        # make sure the largest batch comes first to OOM sooner rather than later\n        largest_global_batch = 0\n        max_tokens = 0\n        
for global_batch_idx, batch in enumerate(global_batches):\n            total_batch_tokens = batch_size_tokens_after_padding(batch)\n            if total_batch_tokens > max_tokens:\n                max_tokens = total_batch_tokens\n                largest_global_batch = global_batch_idx\n        global_batches[0], global_batches[largest_global_batch] = (\n            global_batches[largest_global_batch],\n            global_batches[0],\n        )\n\n        batches_for_this_rank = [\n            global_batch[self.rank : len(global_batch) : self.num_replicas] for global_batch in global_batches\n        ]\n        self.indices = [[i for i, _ in batch] for batch in batches_for_this_rank]\n\n    def should_emit_current_batch(self, current_batch, slice):\n        if not self.batch_size_tokens:\n            batch_size_after_appending = len(current_batch) // self.chunk_size + 1\n            if batch_size_after_appending > self.batch_size:\n                return True\n            else:\n                return False\n        else:\n            global_batch_size_tokens = self.batch_size_tokens * self.chunk_size\n            current_batch_tokens_after_appending = batch_size_tokens_after_padding(current_batch + slice)\n            if len(current_batch) > 0 and current_batch_tokens_after_appending > global_batch_size_tokens:\n                return True\n            return False\n\n    def __iter__(self):\n        return iter(self.indices)\n\n    def __len__(self):\n        return len(self.indices)\n\n\nclass PipelineDataLoader:\n    def __init__(\n        self,\n        dataset,\n        tokenizer,\n        batch_size,\n        gradient_accumulation_steps,\n        data_parallel_world_size,\n        data_parallel_rank,\n        shuffle=True,\n        group_by_length=False,\n        pad_to_multiple_of=PAD_TO_MULTIPLE,\n        batch_size_tokens=None,\n        return_dict=False,\n        rl=False,\n    ):\n        assert data_parallel_rank < data_parallel_world_size\n        
self.dataset = dataset\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.batch_size_tokens = batch_size_tokens\n        self.gradient_accumulation_steps = gradient_accumulation_steps\n        self.pad_to_multiple_of = pad_to_multiple_of\n        self.return_dict = return_dict\n        self.rl = rl\n        self.data_sampler = DistributedBatchSamper(\n            dataset=dataset,\n            batch_size=self.batch_size,\n            batch_size_tokens=self.batch_size_tokens,\n            batch_size_multiplier=self.gradient_accumulation_steps,\n            num_replicas=data_parallel_world_size,\n            rank=data_parallel_rank,\n            shuffle=shuffle,\n            group_by_length=group_by_length,\n        )\n\n        data_collator = DataCollatorForSeq2Seq(self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)\n        def collate_fn(examples, gradient_accumulation_steps=self.gradient_accumulation_steps):\n            has_batch_dimension = examples[0]['input_ids'].ndim == 2\n            if has_batch_dimension:\n                examples = flatten_examples(examples)\n            rejected_examples = []\n            for example in examples:\n                example.pop('length', None)\n                example.pop('token_type_ids', None)\n                rejected_example = {}\n                for key in list(example.keys()):\n                    if 'rejected_' in key:\n                        x = example.pop(key)\n                        # Just drop the rejected_ entries if not doing RL. 
This allows normal SFT on just the\n                        # accepted completions of a RL dataset.\n                        if self.rl:\n                            rejected_example[key.strip('rejected_')] = x\n                if rejected_example:\n                    rejected_examples.append(rejected_example)\n            if rejected_examples:\n                examples = combine_piecewise(examples, rejected_examples, gradient_accumulation_steps)\n            return data_collator(examples)\n        self.collate_fn = collate_fn\n\n        self.epoch = 1\n        self.num_batches_pulled = 0\n        self.next_micro_batch = None\n        self.recreate_dataloader = False\n        self._create_dataloader()\n        self.data = self._pull_batches_from_dataloader()\n\n    def reset(self):\n        self.epoch = 1\n        self.num_batches_pulled = 0\n        self.next_micro_batch = None\n        self.data = self._pull_batches_from_dataloader()\n\n    def __iter__(self):\n        return self\n\n    def __len__(self):\n        return len(self.data_sampler) * self.gradient_accumulation_steps\n\n    def __next__(self):\n        if self.next_micro_batch is None:\n            self.next_micro_batch = next(self.data)\n        ret = self.next_micro_batch\n        try:\n            self.next_micro_batch = next(self.data)\n        except StopIteration:\n            if self.recreate_dataloader:\n                self._create_dataloader()\n                self.recreate_dataloader = False\n            self.data = self._pull_batches_from_dataloader()\n            self.num_batches_pulled = 0\n            self.next_micro_batch = next(self.data)\n            self.epoch += 1\n        return ret\n\n    def _pull_batches_from_dataloader(self):\n        for batch in self.dataloader:\n            self.num_batches_pulled += 1\n            for micro_batch in split_batch(batch, self.gradient_accumulation_steps):\n                if self.return_dict:\n                    yield micro_batch\n          
      else:\n                    # input to pipeline is (input_ids, attention_mask, labels)\n                    # this needs to return (features, labels)\n                    # it is OK if labels is None (the model just returns the loss anyway)\n                    yield example_to_tuple(micro_batch)\n\n    def _create_dataloader(self):\n        self.dataloader = DataLoader(\n            self.dataset,\n            pin_memory=True,\n            batch_sampler=self.data_sampler,\n            collate_fn=self.collate_fn,\n        )\n\n    def state_dict(self):\n        return {\n            'epoch': self.epoch,\n            'num_batches_pulled': self.num_batches_pulled,\n        }\n\n    def load_state_dict(self, state_dict):\n        self.epoch = state_dict['epoch']\n        # -1 because by preloading the next micro_batch, it's always going to have one more batch\n        # pulled than the actual number of batches iterated by the caller.\n        self.num_batches_pulled = state_dict['num_batches_pulled'] - 1\n        self._create_dataloader()\n        self.dataloader = accelerate.skip_first_batches(self.dataloader, self.num_batches_pulled)\n        self.data = self._pull_batches_from_dataloader()\n        # Recreate the dataloader after the first pass so that it won't skip\n        # batches again (we only want it to skip batches the first time).\n        self.recreate_dataloader = True\n\n    # Only the first and last stages in the pipeline pull from the dataloader. 
Parts of the code need\n    # to know the epoch, so we synchronize the epoch so the processes that don't use the dataloader\n    # know the current epoch.\n    def sync_epoch(self):\n        process_group = dist.get_world_group()\n        result = [None] * dist.get_world_size(process_group)\n        torch.distributed.all_gather_object(result, self.epoch, group=process_group)\n        max_epoch = -1\n        for epoch in result:\n            max_epoch = max(epoch, max_epoch)\n        self.epoch = max_epoch\n\n\n# for testing\nif __name__ == '__main__':\n    tokenizer = transformers.AutoTokenizer.from_pretrained(sys.argv[1], local_files_only=True)\n    tokenizer.pad_token_id = 1000\n\n    from datasets import Dataset\n\n    data = []\n    for i in range(1, 41):\n        input_ids = torch.tensor([i] * i)\n        data.append(\n            {\n                'input_ids': input_ids,\n                'attention_mask': torch.ones_like(input_ids),\n                'labels': input_ids,\n                'length': len(input_ids),\n            }\n        )\n    dataset = Dataset.from_list(data)\n\n    # dataloader = PipelineDataLoader(dataset, tokenizer, batch_size=2, gradient_accumulation_steps=2, data_parallel_world_size=1, data_parallel_rank=0, group_by_length=True, pad_to_multiple_of=None)\n    # for batch in dataloader:\n    #     if dataloader.epoch > 1:\n    #         break\n    #     print(batch)\n    #     print()\n\n    batch_size = 2\n    gradient_accumulation_steps = 2\n    data_parallel_world_size = 2\n    data_parallel_rank = 0\n    dataloader = PipelineDataLoader(\n        dataset,\n        tokenizer,\n        batch_size=batch_size,\n        gradient_accumulation_steps=gradient_accumulation_steps,\n        data_parallel_world_size=data_parallel_world_size,\n        data_parallel_rank=data_parallel_rank,\n        shuffle=False,\n        group_by_length=False,\n        pad_to_multiple_of=None,\n    )\n    print(next(dataloader)[0][0])\n    
print(next(dataloader)[0][0])\n    print(next(dataloader)[0][0])\n    print(next(dataloader)[0][0])\n\n    state_dict = dataloader.state_dict()\n    dataloader = PipelineDataLoader(\n        dataset,\n        tokenizer,\n        batch_size=batch_size,\n        gradient_accumulation_steps=gradient_accumulation_steps,\n        data_parallel_world_size=data_parallel_world_size,\n        data_parallel_rank=data_parallel_rank,\n        shuffle=False,\n        group_by_length=False,\n        pad_to_multiple_of=None,\n    )\n    dataloader.load_state_dict(state_dict)\n    print()\n    print('-' * 80)\n    print()\n    print(next(dataloader)[0][0])\n    print(next(dataloader)[0][0])\n    print(next(dataloader)[0][0])\n    print(next(dataloader)[0][0])\n"
  },
  {
    "path": "utils/dataset_utils.py",
    "content": "import os\nimport os.path\nimport sys\n\n\nsys.path.insert(0, os.path.abspath('axolotl/src'))\n\nimport datasets\nimport torch\nimport yaml\nfrom tqdm import tqdm\n\nfrom axolotl.utils.data import prepare_dataset\nfrom axolotl.utils.dict import DictDefault\nfrom utils.utils import is_main_process, zero_first\n\n\nNUM_PROC = min(64, os.cpu_count())\n\n\ndef yield_sequences_from_token_batch(tokenizer, token_batch, sequence_len):\n    # Initialize sequence_tokens with BOS token if it exists\n    sequence_tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []\n    for tokens in tqdm(token_batch):\n        tokens = tokens.tolist()\n        assert len(tokens) > 0, 'Empty tokens list'\n        if tokens[-1] != tokenizer.eos_token_id:\n            tokens.append(tokenizer.eos_token_id)\n        idx = 0\n        # Skip the auto-generated BOS token if present\n        if tokenizer.bos_token_id is not None and tokens[0] == tokenizer.bos_token_id:\n            idx += 1\n        while idx < len(tokens):\n            # Calculate how many tokens are needed to fill the sequence\n            need = sequence_len - len(sequence_tokens)\n            taken = tokens[idx : idx + need]\n            idx += len(taken)\n            sequence_tokens.extend(taken)\n            if len(sequence_tokens) >= sequence_len:\n                assert len(sequence_tokens) == sequence_len\n                yield sequence_tokens\n                # Reset sequence_tokens with BOS token if it exists\n                sequence_tokens = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []\n    # yield anything remaining\n    # TODO: disabled until I get training working with variable length sequences\n    # if len(sequence_tokens) > 0:\n    #     yield sequence_tokens\n\n\ndef slice_into_chunks(x, sequence_len, overlap=0):\n    result = []\n    step = sequence_len - overlap\n    for i in range(0, len(x), step):\n        result.append(x[i : i + 
sequence_len])\n    return result\n\n\ndef load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size, overlap=0, subsample_documents=None):\n    if dataset_path.endswith('.txt'):\n        dataset = datasets.load_dataset('text', data_files=dataset_path, sample_by='document')['train']\n    elif dataset_path.endswith('.json') or dataset_path.endswith('.jsonl'):\n        dataset = datasets.load_dataset('json', data_files=dataset_path)['train']\n    else:\n        raise NotImplementedError()\n    dataset.set_format(type='torch')\n\n    if subsample_documents:\n        dataset = dataset.shuffle(seed=13).select(list(range(int(subsample_documents * len(dataset)))))\n\n    dataset = dataset.map(\n        lambda x: tokenizer(x['text']),\n        batched=True,\n        batch_size=10,\n        remove_columns=dataset.column_names,\n        desc='tokenizing',\n        num_proc=NUM_PROC,\n    )\n    # TODO: maybe do it this way instead\n    # dataset = dataset.map(lambda x: {'tokens': slice_into_chunks(x['tokens'][0], sequence_len, overlap=overlap)}, batched=True, batch_size=1)\n    dataset = dataset.map(\n        lambda x: {'input_ids': list(yield_sequences_from_token_batch(tokenizer, x['input_ids'], sequence_len))},\n        batched=True,\n        batch_size=None,\n        remove_columns=dataset.column_names,\n        desc='splitting',\n    )\n    dataset = dataset.map(\n        lambda x: {'attention_mask': torch.ones_like(x['input_ids']), 'labels': x['input_ids']},\n        desc='adding attention_mask and labels',\n    )\n    if eval_size > 0:\n        split_datasets = dataset.train_test_split(test_size=eval_size, shuffle=True, seed=42)\n        train_data = split_datasets['train']\n        eval_data = split_datasets['test']\n    else:\n        train_data = dataset\n        eval_data = None\n    return train_data, eval_data\n\n\ndef load_axolotl_dataset(dataset_path, tokenizer, sequence_len, eval_size):\n    with open(dataset_path) as f:\n        cfg = 
yaml.safe_load(f.read())\n    if 'val_set_size' not in cfg:\n        cfg['val_set_size'] = 0 if eval_size is None else eval_size\n    cfg['sequence_len'] = sequence_len\n    cfg['tokenizer_config'] = 'dummy'\n    # these don't matter, but they have to be set\n    cfg['batch_size'] = 1\n    cfg['num_epochs'] = 1\n    cfg['sequence_parallel_degree'] = 1\n    cfg = DictDefault(cfg)\n    train_data, eval_data, *_ = prepare_dataset(cfg, tokenizer)\n    train_data.set_format(type='torch')\n    if eval_data is not None:\n        eval_data.set_format(type='torch')\n    return train_data, eval_data\n\n\ndef load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eval_size):\n    ds = datasets.load_from_disk(dataset_path)\n    assert 'input_ids' in ds.column_names\n    assert 'attention_mask' in ds.column_names\n    assert 'labels' in ds.column_names\n    ds = ds.filter(lambda example: len(example['input_ids']) <= sequence_len, desc='dropping long sequences', num_proc=NUM_PROC)\n    if eval_size > 0:\n        split_datasets = ds.train_test_split(test_size=eval_size, shuffle=True, seed=42)\n        train_data = split_datasets['train']\n        eval_data = split_datasets['test']\n    else:\n        train_data = ds\n        eval_data = None\n    return train_data, eval_data\n\n\ndef load_single_dataset(dataset_config, tokenizer):\n    dataset_path = dataset_config['dataset_path']\n    dataset_type = dataset_config['dataset_type']\n    sequence_len = dataset_config['sequence_len']\n    eval_size = dataset_config.get('eval_size', 0)\n    subsample = dataset_config.get('subsample', None)\n    num_repeats = dataset_config.get('num_repeats', None)\n    if dataset_type in ['textfile', 'doclist']:\n        with zero_first(is_main_process()):\n            train_data, eval_data = load_raw_dataset(dataset_path, tokenizer, sequence_len, eval_size)\n    elif dataset_type == 'axolotl':\n        train_data, eval_data = load_axolotl_dataset(dataset_path, tokenizer, sequence_len, 
eval_size)\n    elif dataset_type == 'pretokenized':\n        with zero_first(is_main_process()):\n            train_data, eval_data = load_pretokenized_dataset(dataset_path, tokenizer, sequence_len, eval_size)\n    else:\n        raise NotImplementedError()\n\n    train_data = train_data.shuffle(seed=42)\n    if eval_data is not None:\n        eval_data = eval_data.shuffle(seed=42)\n\n    if subsample is not None:\n        assert 0 < subsample < 1\n        train_data = train_data.select(range(int(len(train_data) * subsample)))\n        if eval_data is not None:\n            eval_data = eval_data.select(range(int(len(eval_data) * subsample)))\n\n    def add_length(x):\n        length = len(x['input_ids'])\n        if 'rejected_input_ids' in x:\n            length = max(length, len(x['rejected_input_ids']))\n        return {'length': length}\n\n    with zero_first(is_main_process()):\n        train_data = train_data.map(add_length, desc='adding length field', num_proc=NUM_PROC)\n        if eval_data is not None:\n            eval_data = eval_data.map(add_length, desc='adding length field', num_proc=NUM_PROC)\n\n    if 'prompt_attention_mask' in train_data.column_names:\n        train_data = train_data.remove_columns('prompt_attention_mask')\n        if eval_data is not None:\n            eval_data = eval_data.remove_columns('prompt_attention_mask')\n\n    if num_repeats:\n        train_data = train_data.repeat(num_repeats)\n\n    if is_main_process():\n        print(f'train_data size: {len(train_data)}')\n        if eval_data is not None:\n            print(f'eval_data size: {len(eval_data)}')\n    return train_data, eval_data\n\n\ndef combine_datasets(dataset_list, config, sample_weights):\n    sample_weights = torch.tensor(sample_weights, dtype=torch.float32)\n    mode = config.get('dataset_combination_mode', 'concatenate')\n    if mode == 'concatenate':\n        dataset = datasets.concatenate_datasets(dataset_list)\n    elif mode == 'interleave':\n        if 
'batch_size_tokens' in config:\n            # batches are formed so they have equal token counts, so interleave datasets based on token counts, not rows\n            avg_lengths = torch.tensor(\n                [dataset['length'].to(torch.float32).mean() for dataset in dataset_list], dtype=torch.float32\n            )\n            sample_weights = sample_weights / avg_lengths\n        sample_weights = sample_weights.to(\n            torch.float64\n        )  # float64 or interleave_datasets complains that probs don't sum to 1\n        probs = sample_weights / sample_weights.sum()\n        dataset = datasets.interleave_datasets(\n            dataset_list,\n            probabilities=probs,\n            seed=42,\n            stopping_strategy=config.get('dataset_interleave_stopping_strategy', 'first_exhausted'),\n        )\n    else:\n        raise ValueError(mode)\n    return dataset\n\n\n# TODO: reduce the extra unneeded left padding caused by accepted and rejected being collated\n# together.\ndef process_dataset_for_rejected_sampling(dataset):\n\n    def _rejected_sampling_map_fn(example):\n        input_ids = example['input_ids']\n        attention_mask = example['attention_mask']\n        labels = example['labels']\n        assert input_ids.ndim == 1\n        assert attention_mask.ndim == 1\n        assert labels.ndim == 1\n\n        # Labels are -100 for masked tokens (the prompt).\n        label_mask = (labels == -100).to(torch.int32)\n\n        if (label_mask == 0).all():\n            # Raw text dataset\n            # TODO: allow configuring this\n            completion_start = len(label_mask) // 2\n            labels[:completion_start] = -100\n        else:\n            # index of first False\n            completion_start = torch.argmin(label_mask).item()\n        return {\n            'labels': labels,\n            'rejected_input_ids': input_ids[:completion_start],\n            'rejected_attention_mask': attention_mask[:completion_start],\n            
'rejected_labels': labels[:completion_start],\n        }\n\n    return dataset.map(_rejected_sampling_map_fn, desc='Processing dataset for negative sampling', num_proc=NUM_PROC)\n\n\ndef load_datasets(config, tokenizer):\n    if 'datasets' not in config:\n        raise ValueError('Need to specify at least one dataset')\n    train_datasets = []\n    sample_weights = []\n    eval_datasets = {}\n    i = 0\n    for dataset_config in config['datasets']:\n        if 'name' in dataset_config:\n            name = dataset_config['name']\n        else:\n            name = f'dataset{i}'\n            i += 1\n        sample_weights.append(dataset_config.get('sample_weight', 1.0))\n        train, eval = load_single_dataset(dataset_config, tokenizer)\n        train_datasets.append(train)\n        if eval is not None:\n            eval_datasets[name] = eval\n\n    for dataset_config in config.get('eval_datasets', []):\n        if 'name' in dataset_config:\n            name = dataset_config['name']\n        else:\n            name = f'dataset{i}'\n            i += 1\n        eval, _ = load_single_dataset(dataset_config, tokenizer)\n        eval_datasets[name] = eval\n\n    if len(train_datasets) == 1:\n        train_dataset = train_datasets[0]\n    else:\n        with zero_first(is_main_process()):\n            train_dataset = combine_datasets(train_datasets, config, sample_weights=sample_weights)\n\n    if config.get('rejected_sampling', False):\n        assert 'rl' in config\n        train_dataset = process_dataset_for_rejected_sampling(train_dataset)\n        eval_datasets = {name: process_dataset_for_rejected_sampling(ds) for name, ds in eval_datasets.items()}\n\n    if 'rl' in config:\n        assert 'rejected_input_ids' in train_dataset.column_names\n        for eval_dataset in eval_datasets.values():\n            assert 'rejected_input_ids' in eval_dataset.column_names\n\n    train_dataset.set_format(type='torch')\n    for eval_dataset in eval_datasets.values():\n        
eval_dataset.set_format(type='torch')\n\n    return train_dataset, eval_datasets\n\n\n# for testing\nif __name__ == '__main__':\n    import transformers\n    # from datasets import disable_caching\n    # disable_caching()\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\n        sys.argv[1], local_files_only=True, use_fast=False, legacy=True\n    )\n    tokenizer.pad_token_id = 0\n    tokenizer.padding_side = 'right'\n    train_data1, eval_data1 = load_raw_dataset('/home/anon/data/test/txt/*.txt', tokenizer, 100, 0.5)\n    train_data2, eval_data2 = load_raw_dataset('/home/anon/data/test/json/*.jsonl', tokenizer, 100, 0.5)\n    print(len(train_data1))\n    print(len(train_data2))\n"
  },
  {
    "path": "utils/engine.py",
    "content": "from collections import deque\nimport time\n\nimport deepspeed\nimport torch\nimport transformers\nfrom deepspeed import comm as dist\nfrom deepspeed.accelerator import get_accelerator\nfrom deepspeed.runtime import utils as ds_utils\nfrom deepspeed.runtime.activation_checkpointing import checkpointing as ds_checkpointing\nfrom deepspeed.runtime.config import DeepSpeedConfig\nfrom deepspeed.runtime.pipe import p2p, schedule\nfrom deepspeed.runtime.pipe.engine import (\n    BATCH_INPUT_TIMER,\n    PIPE_RECV_GRAD_TIMER,\n    PIPE_RECV_INPUT_TIMER,\n    PIPE_SEND_GRAD_TIMER,\n    PIPE_SEND_OUTPUT_TIMER,\n    TRAIN_BATCH_TIMER,\n    PipelineEngine,\n)\nfrom deepspeed.runtime.pipe.module import LayerSpec, PipelineModule\nfrom deepspeed.runtime.pipe.schedule import (\n    BackwardPass,\n    BufferOpInstruction,\n    ForwardPass,\n    OptimizerStep,\n    PipeInstruction,\n    PipeSchedule,\n    RecvActivation,\n    RecvGrad,\n    ReduceGrads,\n    ReduceTiedGrads,\n    SendActivation,\n    SendGrad,\n    _is_even,\n    _is_odd,\n)\nfrom deepspeed.runtime.pipe.topology import ProcessTopology\nfrom deepspeed.runtime.utils import PartitionedTensor\nfrom torch import nn\n\nfrom utils.utils import eta_str, log, is_main_process\nfrom utils.dataloader import split_batch, example_to_tuple\n\n\ndef initialize(\n    args=None, model=None, config=None, **kwargs\n):\n    assert model is not None, 'deepspeed.initialize requires a model'\n\n    dist_backend = get_accelerator().communication_backend_name()\n    dist.init_distributed(dist_backend=dist_backend)\n\n    if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:\n        config = args.deepspeed_config\n\n    mpu = model.mpu()\n    config_class = DeepSpeedConfig(config, mpu)\n    engine = CustomPipelineEngine(\n        args=args,\n        model=model,\n        mpu=mpu,\n        config=config,\n        config_class=config_class,\n        **kwargs,\n    )\n\n    return engine, 
engine.optimizer\n\n\ndef unpack_accepted_rejected(example):\n    batch_size = example['input_ids'].size(0)\n    half = batch_size // 2\n    for key, tensor in list(example.items()):\n        assert tensor.size(0) == batch_size\n        example[key] = tensor[:half]\n        example['rejected_'+key] = tensor[half:]\n    return example\n\n\nclass LoadMicroBatchMultipleBuffers(PipeInstruction):\n    def __init__(self, *buffer_ids, **kwargs):\n        super().__init__(buffer_ids=buffer_ids, **kwargs)\n\n\nclass ReferenceLogitsForwardPass(BufferOpInstruction):\n    pass\n\n\nclass CustomPipelineEngine(PipelineEngine):\n    def __init__(\n        self,\n        *args,\n        lora_model=None,\n        tokenizer=None,\n        rl_config=None,\n        rejected_sampling=False,\n        rejected_sampling_max_new_tokens=1e9,\n        sampling_temperature=1.0,\n        sampling_min_p=0.,\n        sampling_temperature_last=False,\n        **kwargs\n    ):\n        super().__init__(*args, **kwargs)\n        self.total_steps = None\n        self.etas = deque()\n        self.rl_config = {}\n        # Assign list to avoid registering the nn.Module\n        self.lora_model = [lora_model]\n        self.tokenizer = tokenizer\n        self.rl_config = rl_config\n        self.rejected_sampling = rejected_sampling\n        self.rejected_sampling_max_new_tokens = rejected_sampling_max_new_tokens\n        eos_token_ids = set()\n        if self.tokenizer is not None and self.tokenizer.eos_token_id is not None:\n            eos_token_ids.add(self.tokenizer.eos_token_id)\n        model_config = self.module.model.config\n        if model_config.eos_token_id:\n            model_eos_token_ids = model_config.eos_token_id\n            if isinstance(model_eos_token_ids, int):\n                model_eos_token_ids = [model_eos_token_ids]\n            eos_token_ids.update(model_eos_token_ids)\n        self.eos_token_ids = eos_token_ids\n\n        # Sampling configuration. 
Only supports logits processors that don't use input_ids.\n        self.logits_processor = transformers.LogitsProcessorList()\n        temp = transformers.TemperatureLogitsWarper(float(sampling_temperature))\n        if sampling_min_p > 0:\n            self.logits_processor.append(transformers.MinPLogitsWarper(float(sampling_min_p)))\n        if sampling_temperature_last:\n            self.logits_processor.append(temp)\n        else:\n            self.logits_processor.insert(0, temp)\n\n\n    def set_dataloader(self, loader):\n        self.collate_fn = loader.collate_fn\n        if self.is_first_stage() or self.is_last_stage():\n            self.training_dataloader = loader\n            self.data_iterator = iter(self.training_dataloader)\n\n\n    def train_batch(self):\n        if not torch._C.is_grad_enabled():\n            raise RuntimeError(f'train_batch() requires gradients enabled. Use eval_batch() instead.')\n        self.timers(TRAIN_BATCH_TIMER).start()\n\n        # sequence length may change between macro batches (but not between gradient accumulation steps)\n        self.reset_activation_shape()\n\n        train_iterator = self.data_iterator\n\n        # Negative sampling\n        if self.rejected_sampling:\n            assert self.rl_config\n            self.module.eval()\n            model_inputs = self._sample_from_iterator(train_iterator, self.collate_fn)\n            dist.barrier()\n            if model_inputs is not None:\n                self.set_dataiterator(iter(model_inputs))\n            self.reset_activation_shape()\n\n        self.module.train()\n        self._compute_loss = True\n\n        # Do the work\n        if self.rl_config:\n            method = self.rl_config.get('method', None)\n            if method == 'dpo':\n                sched = DPOTrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id)\n            else:\n                raise NotImplementedError(method)\n        else:\n            
sched = schedule.TrainSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id)\n        self._exec_schedule(sched)\n        agg_losses = self._aggregate_total_losses()\n        # Actual training loss is always the first item.\n        self.agg_train_loss = agg_losses[0].mean()\n\n        self.timers(TRAIN_BATCH_TIMER).stop()\n\n        if self.global_steps % self.steps_per_print() == 0:\n            if self.global_rank == 0:\n                elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0\n                iter_time = elapsed / self.steps_per_print()\n                eta = iter_time * (self.total_steps - self.global_steps)\n                self.etas.append(eta)\n                while len(self.etas) > 10:\n                    self.etas.popleft()\n                rolling_eta = sum(self.etas) / len(self.etas)\n                tput = self.train_batch_size() / iter_time\n                log(\n                    f'step: {self.global_steps:>5} / {self.total_steps:>5} '\n                    f'loss: {self.agg_train_loss:0.4f} '\n                    f'iter time (s): {iter_time:0.3f} '\n                    f'samples/sec: {tput:0.3f} '\n                    f'eta: {eta_str(rolling_eta)} '\n                )\n            else:\n                self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True)\n\n        # Monitoring\n        if self.global_rank == 0 and self.monitor.enabled:\n            self.summary_events = [\n                ('Train/Samples/train_loss', self.agg_train_loss.mean().item(), self.global_samples)\n            ]\n            self.monitor.write_events(self.summary_events)\n\n        if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0:\n            self.timers.log(\n                [\n                    PIPE_SEND_OUTPUT_TIMER,\n                    PIPE_SEND_GRAD_TIMER,\n                    PIPE_RECV_INPUT_TIMER,\n                    PIPE_RECV_GRAD_TIMER,\n                ]\n      
      )\n\n        # Restore the training iterator\n        self.set_dataiterator(train_iterator)\n\n        return agg_losses\n\n    def eval_batch(self, data_iter):\n        # sequence length may change between macro batches (but not between gradient accumulation steps)\n        self.reset_activation_shape()\n\n        self.module.eval()\n        self._compute_loss = True\n\n        # Use the provided data iterator\n        train_iterator = self.data_iterator\n        self.set_dataiterator(data_iter)\n\n        # Negative sampling\n        if self.rejected_sampling:\n            assert self.rl_config\n            dist.barrier()\n            model_inputs = self._sample_from_iterator(data_iter, self.collate_fn)\n            if model_inputs is not None:\n                self.set_dataiterator(iter(model_inputs))\n            self.reset_activation_shape()\n\n        # Do the work\n        if self.rl_config:\n            method = self.rl_config.get('method', None)\n            if method == 'dpo':\n                sched = DPOInferenceSchedule(micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id)\n            else:\n                raise NotImplementedError(method)\n        else:\n            sched = schedule.InferenceSchedule(\n                micro_batches=self.micro_batches, stages=self.num_stages, stage_id=self.stage_id\n            )\n\n        # prevent dead-lock with multiple evals sequence\n        dist.barrier()\n\n        with torch.no_grad():\n            self._exec_schedule(sched)\n\n        # list of losses\n        agg_eval_losses = self._aggregate_total_losses()\n\n        if self.global_rank == 0 and self.monitor.enabled:\n            self.summary_events = [('Train/Samples/eval_loss', agg_eval_losses[0].mean().item(), self.global_samples)]\n            self.monitor.write_events(self.summary_events)\n\n        # Restore the training iterator\n        self.set_dataiterator(train_iterator)\n\n        return agg_eval_losses\n\n    
def sample_batch(self, prompts):\n        assert isinstance(prompts, (list, tuple))\n        self.reset_activation_shape()\n        self.module.eval()\n        self.module.set_sampling_mode(True)\n        original_micro_batches = self.micro_batches\n        self.micro_batches = len(prompts)\n        dist.barrier()\n\n        if self.is_first_stage():\n            # Tokenizer returns dict with 'input_ids', 'attention_mask' keys.\n            # Tensors have batch dimension because we pass list of prompts.\n            examples = []\n            for prompt in prompts:\n                if not isinstance(prompt, (list, tuple)):\n                    prompt = [prompt]\n                examples.append(self.tokenizer(prompt, return_tensors='pt', padding=True))\n        else:\n            examples = None\n        with torch.no_grad():\n            examples = self._exec_sampling_schedule(examples)\n        if self.is_first_stage():\n            text = [self.tokenizer.batch_decode(example['input_ids']) for example in examples]\n        else:\n            text = None\n        self.micro_batches = original_micro_batches\n        self.module.set_sampling_mode(False)\n        return text\n\n\n    def _sample_from_iterator(self, data_iter, collate_fn):\n        if data_iter is not None:\n            examples = [unpack_accepted_rejected(next(data_iter)) for _ in range(self.micro_batches)]\n            # TODO: allow configuring this\n            max_total_tokens = max(example['input_ids'].size(1) for example in examples) * 2\n        else:\n            examples = None\n            # This is okay, max_total_tokens is only checked on the first stage.\n            max_total_tokens = None\n        self.module.eval()\n        self.module.set_sampling_mode(True)\n        dist.barrier()\n        with torch.no_grad():\n            examples = self._exec_sampling_schedule(examples, feature_prefix='rejected_', max_total_tokens=max_total_tokens)\n        if is_main_process():\n            
input_ids = examples[0]['rejected_input_ids'][0]\n            attention_mask = examples[0]['rejected_attention_mask'][0]\n            start = torch.argmax(attention_mask)\n            end = len(attention_mask) - torch.argmax(torch.flip(attention_mask, (0,)))\n            text = self.tokenizer.decode(input_ids[start:end])\n            print(f'Example of sampled rejected completion:\\n{text}')\n        self.module.set_sampling_mode(False)\n        if examples is not None:\n            batch = collate_fn(examples, gradient_accumulation_steps=self.micro_batches)\n            model_inputs = [example_to_tuple(micro_batch) for micro_batch in split_batch(batch, self.micro_batches)]\n            return model_inputs\n        else:\n            return None\n\n\n    def _aggregate_total_losses(self):\n        all_agg_outputs = []\n        # gather each output for all the gradient accumulation steps\n        grouped_outputs = [list(x) for x in zip(*self.fwd_outputs)]\n        # if any are scalar, make them dim 1 so we can concat across DP ranks\n        for outputs in grouped_outputs:\n            for i, output in enumerate(outputs):\n                if output.dim() == 0:\n                    outputs[i] = torch.unsqueeze(output, 0)\n\n        if self.is_last_stage():\n            agg_sizes = []\n            # loop to gather all the outputs across DP ranks\n            for outputs in grouped_outputs:\n                # concat all the grad_accum_steps\n                concat_outputs = torch.cat(outputs)\n                if self.is_data_parallel:\n                    # might be different sizes across DP ranks, so, gather all the sizes\n                    sizes = [None] * self.grid.get_data_parallel_world_size()\n                    torch.distributed.all_gather_object(\n                        sizes, concat_outputs.size(), group=self.grid.get_data_parallel_group()\n                    )\n                    # once we know all the sizes we can gather the results across DP ranks\n   
                 gather_result = [torch.zeros(size).to(self.device) for size in sizes]\n                    dist.all_gather(gather_result, concat_outputs, group=self.grid.get_data_parallel_group())\n                    # and finally, concat\n                    agg_output = torch.cat(gather_result)\n                else:\n                    agg_output = concat_outputs\n                agg_sizes.append(agg_output.size())\n                all_agg_outputs.append(agg_output)\n\n            # send the sizes, then broadcast to the PP ranks\n            if self.is_pipe_parallel:\n                torch.distributed.broadcast_object_list(\n                    [agg_sizes], src=self.global_rank, group=self.grid.get_pipe_parallel_group()\n                )\n                for agg_output in all_agg_outputs:\n                    dist.broadcast(tensor=agg_output, src=self.global_rank, group=self.grid.get_pipe_parallel_group())\n        else:\n            # get the outputs from the last stage\n            src_rank = self.grid.stage_to_global(self.num_stages - 1)\n            assert src_rank in self.grid.pp_group\n            result = [None]\n            torch.distributed.broadcast_object_list(result, src=src_rank, group=self.grid.get_pipe_parallel_group())\n            agg_sizes = result[0]\n            for agg_size in agg_sizes:\n                agg_output = torch.zeros(agg_size).to(self.device)\n                dist.broadcast(tensor=agg_output, src=src_rank, group=self.grid.get_pipe_parallel_group())\n                all_agg_outputs.append(agg_output)\n\n        return all_agg_outputs\n\n    # We override this to handle the model returning a list of \"losses\", but only doing backprop on the first.\n    def _exec_forward_pass(self, buffer_id):\n        self.tput_timer.start()\n        self.mem_status('BEFORE FWD', reset_max=True)\n\n        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):\n            inputs = tuple(t.clone() for t in 
self.pipe_buffers['inputs'][buffer_id])\n        else:\n            inputs = self.pipe_buffers['inputs'][buffer_id].clone()\n\n        # collect the partitioned input from the previous stage\n        if self.is_pipe_partitioned and not self.is_first_stage():\n            part_input = PartitionedTensor.from_meta(\n                meta=inputs[0], local_part=inputs[1], group=self.grid.get_slice_parallel_group()\n            )\n\n            inputs = (part_input.full(), *inputs[2:])\n            inputs[0].requires_grad = True\n            # skip mask\n            # inputs[1].requires_grad = True\n            part_input = None\n            inputs = inputs[0] if len(inputs) == 1 else inputs\n            self.pipe_buffers['inputs'][buffer_id] = inputs\n\n        # inputs has no gradient because it is from a cloned tensor\n        outputs = super(PipelineEngine, self).forward(inputs)\n\n        # Reset activation checkpointing buffers.\n        # Need to call this between evaluation iterations\n        if not self.module.training:\n            ds_checkpointing.reset()\n\n        # Partition the outputs if we are not the last stage\n        if self.is_pipe_partitioned and not self.is_last_stage():\n            if isinstance(outputs, tuple):\n                first_output = outputs[0]\n                # TODO: Improve pipe partitioning to pass multiple tensors that require grads\n                assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:])\n                outputs_tail = outputs[1:]\n            elif torch.is_tensor(outputs):\n                first_output = outputs\n                outputs_tail = []\n            else:\n                raise ValueError('expecting a tensor or a tuple of tensors')\n            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())\n            # Clear the large output data, but save the computation graph\n            first_output.data = torch.zeros(1)\n            
self.pipe_buffers['output_tensors'][buffer_id] = first_output\n            # Inject the partitioned tensor into the output before sending\n            outputs = (part.to_meta(), part.data(), *outputs_tail)\n            part = None\n\n        self.pipe_buffers['outputs'][buffer_id] = outputs\n\n        # Optionally compute loss on the last device\n        if self.is_last_stage():\n            if self._compute_loss and self.module.loss_fn is not None:\n                labels = self.pipe_buffers['labels'][buffer_id]\n                losses = self.module.loss_fn(outputs, labels)\n            else:\n                # Some models just return loss from forward()\n                losses = outputs\n            if self.eval_return_logits:\n                self.outputs = outputs\n            if isinstance(losses, torch.Tensor):\n                self.loss = losses\n                self.fwd_outputs.append([self.loss.detach()])\n            else:\n                self.loss = losses[0]\n                self.fwd_outputs.append([l.detach() for l in losses])\n\n    def _exec_load_micro_batch_multiple_buffers(self, buffer_ids):\n        if self.wall_clock_breakdown():\n            self.timers(BATCH_INPUT_TIMER).start()\n\n        batch = self._next_batch()\n\n        if self.is_first_stage():\n            loaded = None\n            if torch.is_tensor(batch[0]):\n                loaded = batch[0].clone().to(self.device).detach()\n                if (\n                    self._config.pipeline['activation_checkpoint_interval'] > 0\n                    and self._config.pipeline['use_reentrant']\n                ):\n                    loaded.requires_grad = loaded.is_floating_point()\n            else:\n                assert isinstance(batch[0], (tuple, list))\n                # Assume list or tuple\n                loaded = []\n                for x in batch[0]:\n                    assert torch.is_tensor(x)\n                    mine = x.clone().detach().to(self.device)\n              
      if (\n                        self._config.pipeline['activation_checkpoint_interval'] > 0\n                        and self._config.pipeline['use_reentrant']\n                    ):\n                        mine.requires_grad = mine.is_floating_point()\n                    loaded.append(mine)\n                loaded = tuple(loaded)\n\n            for buffer_id in buffer_ids:\n                self.pipe_buffers['inputs'][buffer_id] = loaded\n\n        if self.is_last_stage():\n            loaded = batch[1]\n            if torch.is_tensor(batch[1]):\n                loaded = batch[1].to(self.device)\n            # XXX: torch 1.6.0 DataLoader will auto convert tuple to list\n            elif isinstance(batch[1], (tuple, list)):\n                loaded = []\n                for x in batch[1]:\n                    assert torch.is_tensor(x)\n                    x = x.to(self.device).detach()\n                    loaded.append(x)\n                loaded = tuple(loaded)\n\n            for buffer_id in buffer_ids:\n                self.pipe_buffers['labels'][buffer_id] = loaded\n\n        if self.wall_clock_breakdown():\n            self.timers(BATCH_INPUT_TIMER).stop()\n\n    @torch.no_grad()\n    def _exec_reference_logits_forward_pass(self, buffer_id):\n        self.lora_model[0].disable_adapter_layers()\n        self.module.set_dpo_reference_mode(True)\n        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):\n            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])\n        else:\n            inputs = self.pipe_buffers['inputs'][buffer_id].clone()\n\n        # collect the partitioned input from the previous stage\n        if self.is_pipe_partitioned and not self.is_first_stage():\n            if self.pipe_partition_input_meta_cache is None:\n                self.pipe_partition_input_meta_cache = inputs[0].to('cpu')\n            part_input = PartitionedTensor.from_meta(\n                
meta=self.pipe_partition_input_meta_cache,\n                local_part=inputs[1],\n                group=self.grid.get_slice_parallel_group(),\n            )\n\n            inputs = (part_input.full(), *inputs[2:])\n            inputs[0].requires_grad = True\n            # skip mask\n            # inputs[1].requires_grad = True\n            part_input = None\n            inputs = inputs[0] if len(inputs) == 1 else inputs\n            self.pipe_buffers['inputs'][buffer_id] = inputs\n\n        # inputs has no gradient because it is from a cloned tensor\n        outputs = super(PipelineEngine, self).forward(inputs)\n\n        # Reset activation checkpointing buffers.\n        # Need to call this between evaluation iterations\n        if not self.module.training:\n            ds_checkpointing.reset()\n\n        # Partition the outputs if we are not the last stage\n        if self.is_pipe_partitioned and not self.is_last_stage():\n            if isinstance(outputs, tuple):\n                first_output = outputs[0]\n                # TODO: Improve pipe partitioning to pass multiple tensors that require grads\n                assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:])\n                outputs_tail = outputs[1:]\n            elif torch.is_tensor(outputs):\n                first_output = outputs\n                outputs_tail = []\n            else:\n                raise ValueError('expecting a tensor or a tuple of tensors')\n            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())\n            # Clear the large output data, but save the computation graph\n            first_output.data = torch.zeros(1, device=first_output.data.device)\n            self.pipe_buffers['output_tensors'][buffer_id] = first_output\n            # Inject the partitioned tensor into the output before sending\n            outputs = (part.to_meta(), part.data(), *outputs_tail)\n            part = None\n\n        
self.pipe_buffers['outputs'][buffer_id] = outputs\n        self.lora_model[0].enable_adapter_layers()\n        self.module.set_dpo_reference_mode(False)\n\n    def _exec_send_micro_batch_id(self, send_micro_batch_id):\n        assert isinstance(send_micro_batch_id, int)\n        if self.num_stages == 1:\n            return send_micro_batch_id\n        send_micro_batch_id = torch.tensor(send_micro_batch_id, device=self.device)\n        recv_micro_batch_id = torch.tensor(-1, device=self.device)\n        if _is_even(self.stage_id):\n            if not self.is_last_stage():\n                p2p.send(send_micro_batch_id, self.next_stage)\n            if not self.is_first_stage():\n                p2p.recv(recv_micro_batch_id, self.prev_stage)\n        else:\n            if not self.is_first_stage():\n                p2p.recv(recv_micro_batch_id, self.prev_stage)\n            if not self.is_last_stage():\n                p2p.send(send_micro_batch_id, self.next_stage)\n        # last stage sends to first stage\n        if self.is_first_stage():\n            p2p.recv(recv_micro_batch_id, self.num_stages - 1)\n        if self.is_last_stage():\n            p2p.send(send_micro_batch_id, 0)\n        return recv_micro_batch_id.item()\n\n    def _exec_load_micro_batch_for_sampling(self, buffer_id, inputs):\n        loaded = (\n            inputs['input_ids'],\n            inputs['attention_mask'],\n            torch.tensor([0]),  # labels must be provided, so use a dummy\n        )\n        loaded = tuple(x.clone().detach().to(self.device) for x in loaded)\n        self.pipe_buffers['inputs'][buffer_id] = loaded\n\n    @torch.no_grad()\n    def _exec_sampling_forward_pass(self, buffer_id):\n        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):\n            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])\n        else:\n            inputs = self.pipe_buffers['inputs'][buffer_id].clone()\n\n        # collect the partitioned input from the 
previous stage\n        if self.is_pipe_partitioned and not self.is_first_stage():\n            if self.pipe_partition_input_meta_cache is None:\n                self.pipe_partition_input_meta_cache = inputs[0].to('cpu')\n            part_input = PartitionedTensor.from_meta(\n                meta=self.pipe_partition_input_meta_cache,\n                local_part=inputs[1],\n                group=self.grid.get_slice_parallel_group(),\n            )\n\n            inputs = (part_input.full(), *inputs[2:])\n            inputs[0].requires_grad = True\n            # skip mask\n            # inputs[1].requires_grad = True\n            part_input = None\n            inputs = inputs[0] if len(inputs) == 1 else inputs\n            self.pipe_buffers['inputs'][buffer_id] = inputs\n\n        # inputs has no gradient because it is from a cloned tensor\n        outputs = super(PipelineEngine, self).forward(inputs)\n\n        # Reset activation checkpointing buffers.\n        # Need to call this between evaluation iterations\n        if not self.module.training:\n            ds_checkpointing.reset()\n\n        # Partition the outputs if we are not the last stage\n        if self.is_pipe_partitioned and not self.is_last_stage():\n            if isinstance(outputs, tuple):\n                first_output = outputs[0]\n                # TODO: Improve pipe partitioning to pass multiple tensors that require grads\n                assert all(torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:])\n                outputs_tail = outputs[1:]\n            elif torch.is_tensor(outputs):\n                first_output = outputs\n                outputs_tail = []\n            else:\n                raise ValueError('expecting a tensor or a tuple of tensors')\n            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())\n            # Clear the large output data, but save the computation graph\n            first_output.data = 
torch.zeros(1, device=first_output.data.device)\n            self.pipe_buffers['output_tensors'][buffer_id] = first_output\n            # Inject the partitioned tensor into the output before sending\n            outputs = (part.to_meta(), part.data(), *outputs_tail)\n            part = None\n\n        self.pipe_buffers['outputs'][buffer_id] = outputs\n\n    def _sample_from_logits(self, buffer_id):\n        logits = self.pipe_buffers['outputs'][buffer_id].squeeze(1)\n        logits = self.logits_processor(None, logits)\n        probs = torch.nn.functional.softmax(logits, dim=-1)\n        input_ids = torch.multinomial(probs, num_samples=1)\n        # Logically you would squeeze(1) to remove the multinomial num_samples dimension, then\n        # unsqueeze(1) to add back the sequence_length dimension. But those just cancel out.\n        return input_ids\n\n    def _valid_stage(self, stage_id):\n        return 0 <= stage_id < self.num_stages\n\n    def _valid_micro_batch(self, micro_batch_id):\n        return 0 <= micro_batch_id < self.micro_batches\n\n    def _exec_sampling_schedule(self, examples, feature_prefix='', max_total_tokens=1e9):\n        start = time.time()\n        input_ids_key = f'{feature_prefix}input_ids'\n        attention_mask_key = f'{feature_prefix}attention_mask'\n        labels_key = f'{feature_prefix}labels'\n        # Reserve and reset buffers.\n        self._reserve_pipe_buffers(2)\n        self.fwd_outputs = []\n        eos_token_ids = torch.tensor(list(self.eos_token_ids))\n        finished = False\n\n        if self.is_first_stage():\n            num_batches_done = 0\n            num_batches = len(examples)\n            queue = deque()\n            for i, example in enumerate(examples):\n                example['done'] = torch.tensor([False]*example[input_ids_key].size(0))\n                example['num_new_tokens'] = 0\n                queue.append((i, {\n                    'input_ids': example[input_ids_key],\n                    
'attention_mask': example[attention_mask_key],\n                }))\n\n        step_id = 0\n        micro_batch_id = -1\n        prev_micro_batch_id = -1\n        while not finished:\n            # Alternate send/recv buffers\n            if _is_even(self.stage_id):\n                recv_buf = step_id % 2\n                send_buf = (step_id + 1) % 2\n            else:\n                recv_buf = (step_id + 1) % 2\n                send_buf = step_id % 2\n\n            # Load from the queue on the first stage.\n            if self.is_first_stage():\n                if len(queue) > 0:\n                    micro_batch_id, inputs = queue.popleft()\n                    self._exec_load_micro_batch_for_sampling(recv_buf, inputs)\n                else:\n                    micro_batch_id = -1\n\n            # Send / receive activations if needed.\n            if _is_even(self.stage_id):\n                if self._valid_stage(self.next_stage):\n                    if self._valid_micro_batch(prev_micro_batch_id):\n                        self._exec_send_activations(send_buf)\n                if self._valid_stage(self.prev_stage):\n                    if self._valid_micro_batch(micro_batch_id):\n                        self._exec_recv_activations(recv_buf)\n            else:\n                if self._valid_stage(self.prev_stage):\n                    if self._valid_micro_batch(micro_batch_id):\n                        self._exec_recv_activations(recv_buf)\n                if self._valid_stage(self.next_stage):\n                    if self._valid_micro_batch(prev_micro_batch_id):\n                        self._exec_send_activations(send_buf)\n\n            # Send micro_batch_id to next stage. 
Last stage wraps around and sends to first stage.\n            prev_micro_batch_id = micro_batch_id\n            micro_batch_id = self._exec_send_micro_batch_id(micro_batch_id)\n\n            # Run forward().\n            # Note that prev_micro_batch_id is actually the micro_batch_id of the current step.\n            if self._valid_micro_batch(prev_micro_batch_id):\n                self.model.set_cache(prev_micro_batch_id)\n                self._exec_sampling_forward_pass(recv_buf)\n                if self.is_last_stage():\n                    input_ids = self._sample_from_logits(recv_buf)\n                    if self.num_stages > 1:\n                        p2p.send(input_ids, 0)\n\n            # First stage got a valid micro_batch_id from the last stage. Receive the input_ids and process them.\n            if self.is_first_stage() and self._valid_micro_batch(micro_batch_id):\n                example = examples[micro_batch_id]\n                batch_size = example[input_ids_key].size(0)\n\n                if self.num_stages > 1:\n                    input_ids = torch.full((batch_size, 1), -1, device=self.device)\n                    p2p.recv(input_ids, self.num_stages - 1)\n                assert input_ids.size(-1) == 1, input_ids.shape\n\n                if example['num_new_tokens'] >= self.rejected_sampling_max_new_tokens:\n                    finished = True\n                if example[input_ids_key].size(1) >= max_total_tokens:\n                    finished = True\n\n                if not finished:\n                    input_ids = input_ids.to('cpu')\n                    prev_done = example['done']\n                    # Determine which items in the batch are done generating.\n                    done = prev_done | (input_ids == eos_token_ids).any(-1)\n                    example['done'] = done\n                    batch_done = done.all().item()\n                    # Output pad token and 0 attention mask for items in the batch that are already done.\n        
            prev_done_reshaped = prev_done.unsqueeze(-1)\n                    input_ids = torch.where(prev_done_reshaped, self.tokenizer.pad_token_id, input_ids)\n                    attention_mask_extension = torch.where(prev_done_reshaped, 0, 1)\n                    labels_extention = torch.where(prev_done_reshaped, -100, input_ids)\n                    input_ids = torch.cat([example[input_ids_key], input_ids], dim=-1)\n                    example[input_ids_key] = input_ids\n                    if labels_key in example:\n                        example[labels_key] = torch.cat([example[labels_key], labels_extention], dim=-1)\n                    attention_mask = torch.cat([example[attention_mask_key], attention_mask_extension], dim=-1)\n                    example[attention_mask_key] = attention_mask\n                    example['num_new_tokens'] += 1\n                    if batch_done:\n                        num_batches_done += 1\n                        finished = (num_batches_done == num_batches)\n                    else:\n                        # Model needs full attention mask, but only most recent sampled input_id.\n                        queue.append(\n                            (\n                                micro_batch_id,\n                                {\n                                    'input_ids': input_ids[..., -1:],\n                                    'attention_mask': attention_mask,\n                                },\n                            )\n                        )\n\n            # Broadcast finished from first stage to all other stages so they can exit the loop.\n            src_rank = self.grid.stage_to_global(0)\n            finished = [finished] if self.is_first_stage() else [None]\n            torch.distributed.broadcast_object_list(\n                finished, src=src_rank, group=self.grid.get_pipe_parallel_group()\n            )\n            finished = finished[0]\n            step_id += 1\n            # end while 
loop\n\n        if self.is_first_stage():\n            total_new_tokens = 0\n            for example in examples:\n                total_new_tokens += example['num_new_tokens'] * batch_size\n                del example['done']\n                del example['num_new_tokens']\n\n            if is_main_process():\n                duration = time.time() - start\n                tps = total_new_tokens / duration\n                print(f'Total sampling time: {duration:.1f}, average tok/s: {tps:.1f}')\n\n        dist.barrier()\n        return examples\n\n    # make our forward pass method apply\n    PipelineEngine._INSTRUCTION_MAP[schedule.ForwardPass] = _exec_forward_pass\n    PipelineEngine._INSTRUCTION_MAP[LoadMicroBatchMultipleBuffers] = _exec_load_micro_batch_multiple_buffers\n    PipelineEngine._INSTRUCTION_MAP[ReferenceLogitsForwardPass] = _exec_reference_logits_forward_pass\n\n\nclass ColumnMajorParallelTopology(ProcessTopology):\n    \"\"\"\n    A topology specialisation for hybrid data+pipeline parallelism optimized for LoRA training:\n    - Sends high-volume \"per token\" hidden states over PCIe/NVLink.\n    - Sends lower-volume \"per step\" LoRA gradient reductions over Ethernet/InfiniBand.\n    \"\"\"\n\n    def __init__(self, num_pp, num_dp):\n        # Swap the axes and dims to change the rank mapping\n        super().__init__(axes=['data', 'pipe'], dims=[num_dp, num_pp])\n\n\nclass CustomPipelineModule(PipelineModule):\n    def __init__(self, layers, use_column_major_topology, model=None, **kwargs):\n        # Assign to list to avoid registering the nn.Module\n        self._model = [model]\n        # Hybrid LoRA data+pipeline parallelism may want to use \"column-major\" layout\n        if use_column_major_topology:\n            world_size = dist.get_world_size()\n            num_stages = kwargs.get('num_stages')\n            if num_stages > 1 and world_size > 1:\n                assert world_size % num_stages == 0, (\n                    f'world_size 
({world_size}) must be divisible by num_stages ({num_stages})'\n                )\n                num_dp = world_size // num_stages\n                kwargs['topology'] = ColumnMajorParallelTopology(num_pp=num_stages, num_dp=num_dp)\n        super().__init__(layers, **kwargs)\n\n    @property\n    def model(self):\n        return self._model[0]\n\n    def set_dpo_reference_mode(self, dpo_reference_mode):\n        self.model.set_dpo_reference_mode(dpo_reference_mode)\n\n    def set_sampling_mode(self, sampling_mode):\n        self.model.set_sampling_mode(sampling_mode)\n\n    def _partition_layers(self, method='uniform'):\n        num_stages = self._topo.get_dim('pipe')\n        stage_id = self._topo.get_coord(self.global_rank).pipe\n\n        if self.global_rank == 0:\n            print(f'Partitioning pipeline stages with method {method}')\n\n        method = method.lower()\n\n        estimated_sizes = None\n        # Each stage gets a simple uniform number of layers.\n        if method == 'uniform':\n            num_layers = len(self._layer_specs)\n            self.parts = ds_utils.partition_uniform(num_items=num_layers, num_parts=num_stages)\n        elif method == 'parameters':\n            param_counts = self._count_layer_params()\n            self.parts = ds_utils.partition_balanced(weights=param_counts, num_parts=num_stages)\n        elif method.startswith('type:'):\n            layertype = method.split(':')[1]\n            binary_weights = [0] * len(self._layer_specs)\n            for idx in self._find_layer_type(layertype):\n                binary_weights[idx] = 1\n            self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages)\n        elif method == 'profile':\n            raise NotImplementedError(f'Partitioning method {method} not implemented.')\n        elif method == 'estimated_size':\n            estimated_sizes = [getattr(l, 'estimated_size', 0) for l in self._layer_specs]\n            self.parts = 
ds_utils.partition_balanced(weights=estimated_sizes, num_parts=num_stages)\n        else:\n            raise NotImplementedError(f'Partitioning method {method} not implemented.')\n\n        # Print some information on the partitioning.\n        if self.global_rank == 0:\n            for stage in range(num_stages):\n                start = self.parts[stage]\n                stop = self.parts[stage + 1]\n                print(f'stage={stage} layers={stop - start}')\n                for idx, layer in enumerate(self._layer_specs[start:stop]):\n                    name = str(layer)\n                    if isinstance(layer, LayerSpec):\n                        name = layer.typename.__name__\n                    if isinstance(layer, nn.Module):\n                        name = layer.__class__.__name__\n                    else:\n                        try:\n                            name = layer.__name__\n                        except AttributeError:\n                            pass\n                    logstr = f'    {idx + start:2d}: {name}'\n                    if estimated_sizes:\n                        es = estimated_sizes[idx + start]\n                        logstr += f', estimated size: {es}'\n                    print(logstr)\n            if self.loss_fn:\n                try:\n                    print(f'  loss: {self.loss_fn.__name__}')\n                except AttributeError:\n                    print(f'  loss: {self.loss_fn.__class__.__name__}')\n        deepspeed.comm.barrier()\n\n        self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])\n\n\nclass DPOTrainSchedule(PipeSchedule):\n    \"\"\"Train schedule for DPO. 
Does an extra forward pass for the reference logits.\"\"\"\n\n    def steps(self):\n        prev_micro_batch_id = -1\n        total_steps = 2 * (self.micro_batches + self.stages - 1)\n        forward_step_id = 0\n        ref_logits_buf = self.num_pipe_buffers() - 1\n        for step_id in range(total_steps):\n            # Map the step of the pipeline to the micro-batch id and also whether it is a\n            # forward or backward pass step.\n            micro_batch_id, is_forward = self._step_to_micro_batch(step_id)\n\n            if self._valid_micro_batch(prev_micro_batch_id):\n                prev_buffer = self._buffer_idx(prev_micro_batch_id)\n            if self._valid_micro_batch(micro_batch_id):\n                curr_buffer = self._buffer_idx(micro_batch_id)\n\n            cmds = []\n\n            # Exchange activations\n            if is_forward:\n                if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.prev_stage):\n                    cmds.append(SendGrad(prev_buffer))\n                if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.prev_stage):\n                    cmds.append(RecvActivation(ref_logits_buf))\n                    cmds.append(RecvActivation(curr_buffer))\n            else:\n                if self._valid_micro_batch(micro_batch_id) and self._valid_stage(self.next_stage):\n                    cmds.append(RecvGrad(curr_buffer))\n                if self._valid_micro_batch(prev_micro_batch_id) and self._valid_stage(self.next_stage):\n                    cmds.append(SendActivation(ref_logits_buf))\n                    cmds.append(SendActivation(prev_buffer))\n\n            # First/last stage loads\n            if self.stage_id == 0 or self.stage_id == self.stages - 1:\n                if is_forward and self._valid_micro_batch(micro_batch_id):\n                    # Load for normal forward and reference logits forward.\n                    
cmds.append(LoadMicroBatchMultipleBuffers(curr_buffer, ref_logits_buf))\n\n            # Computation\n            if self._valid_micro_batch(micro_batch_id):\n                if is_forward:\n                    # Reference logits forward.\n                    cmds.append(ReferenceLogitsForwardPass(ref_logits_buf))\n                    cmds.append(ForwardPass(curr_buffer))\n                    forward_step_id += 1\n                else:\n                    cmds.append(BackwardPass(curr_buffer))\n\n            # Model step at the end of the batch\n            if step_id == total_steps - 1:\n                cmds.append(ReduceTiedGrads())\n                cmds.append(ReduceGrads())\n                cmds.append(OptimizerStep())\n\n            # Prepare state for next time\n            prev_micro_batch_id = micro_batch_id\n            yield cmds\n\n    def num_pipe_buffers(self):\n        buffers = min(self.stages - self.stage_id, self.micro_batches)\n        # +1 buffer for reference logits forward pass.\n        # Unlike inference, we only need 1 buffer, since alternating forward/backward passes means a stage\n        # is never sending and receiving activations on the same step.\n        return max(2, buffers) + 1\n\n    def _step_to_micro_batch(self, step_id):\n        if _is_even(step_id) and _is_even(self.stage_id):\n            micro_batch_id = self._even_step_forward_id(step_id)\n            is_forward = True\n\n        elif _is_odd(step_id) and _is_odd(self.stage_id):\n            micro_batch_id = self._odd_step_forward_id(step_id)\n            is_forward = True\n\n        elif _is_even(step_id) and _is_odd(self.stage_id):\n            micro_batch_id = self._even_step_backward_id(step_id)\n            is_forward = False\n\n        elif _is_odd(step_id) and _is_even(self.stage_id):\n            micro_batch_id = self._odd_step_backward_id(step_id)\n            is_forward = False\n\n        else:\n            assert False\n\n        return micro_batch_id, 
is_forward\n\n    def _even_step_forward_id(self, step_id):\n        base = step_id // 2\n        micro_batch_id = int(base - self.stage_id // 2)\n        return micro_batch_id\n\n    def _odd_step_forward_id(self, step_id):\n        base = (step_id - 1) // 2\n        micro_batch_id = int(base - self.stage_id // 2)\n        return micro_batch_id\n\n    def _even_step_backward_id(self, step_id):\n        base = step_id // 2\n        micro_batch_id = int(base - self.stages + (self.stage_id + 1) // 2)\n        return micro_batch_id\n\n    def _odd_step_backward_id(self, step_id):\n        base = ((step_id - 1) // 2) - self.stages + 1\n        micro_batch_id = int(base + self.stage_id // 2)\n        return micro_batch_id\n\n    # Override to account for the extra buffer used for reference logit forward pass.\n    def _buffer_idx(self, micro_batch_id):\n        assert self._valid_micro_batch(micro_batch_id)\n        return micro_batch_id % (self.num_pipe_buffers() - 1)\n\n\nclass DPOInferenceSchedule(PipeSchedule):\n    def steps(self):\n        total_steps = self.micro_batches + self.stages - 1\n        for step_id in range(total_steps):\n            cmds = []\n            micro_batch_id = step_id - self.stage_id\n\n            # Alternate send/recv buffers\n            if _is_even(self.stage_id):\n                recv_buf = step_id % 2\n                send_buf = (step_id + 1) % 2\n            else:\n                recv_buf = (step_id + 1) % 2\n                send_buf = step_id % 2\n\n            ref_recv_buf = recv_buf + 2\n            ref_send_buf = send_buf + 2\n\n            if self.is_first_stage or self.is_last_stage:\n                if self._valid_micro_batch(micro_batch_id):\n                    cmds.append(LoadMicroBatchMultipleBuffers(recv_buf, ref_recv_buf))\n\n            if _is_even(self.stage_id):\n                if self._valid_stage(self.next_stage):\n                    if self._valid_micro_batch(micro_batch_id - 1):\n                        
cmds.append(SendActivation(ref_send_buf))\n                        cmds.append(SendActivation(send_buf))\n                if self._valid_stage(self.prev_stage):\n                    if self._valid_micro_batch(micro_batch_id):\n                        cmds.append(RecvActivation(ref_recv_buf))\n                        cmds.append(RecvActivation(recv_buf))\n            else:\n                if self._valid_stage(self.prev_stage):\n                    if self._valid_micro_batch(micro_batch_id):\n                        cmds.append(RecvActivation(ref_recv_buf))\n                        cmds.append(RecvActivation(recv_buf))\n\n                if self._valid_stage(self.next_stage):\n                    if self._valid_micro_batch(micro_batch_id - 1):\n                        cmds.append(SendActivation(ref_send_buf))\n                        cmds.append(SendActivation(send_buf))\n\n            if self._valid_micro_batch(micro_batch_id):\n                cmds.append(ReferenceLogitsForwardPass(ref_recv_buf))\n                cmds.append(ForwardPass(recv_buf))\n\n            yield cmds\n\n    def num_pipe_buffers(self):\n        return 4\n"
  },
  {
    "path": "utils/hqq_utils.py",
    "content": "from dataclasses import asdict, dataclass, field\nfrom typing import Any\n\nimport torch\nimport transformers\nfrom hqq.core import quantize as hqq_quantize\nfrom torch import nn\n\nimport peft\nfrom utils.utils import DTYPE_MAP\n\n\n# Monkeypatch PEFT so that target_modules='all-linear' targets the HQQLinear layers, which are not\n# subclasses of nn.Linear, unlike BNB.\ndef _maybe_include_all_linear_layers(peft_config: peft.PeftConfig, model: nn.Module) -> peft.PeftConfig:\n    \"\"\"\n    Helper function to update `target_modules` to all linear/Conv1D layers if provided as 'all-linear'. Adapted from\n    the QLoRA repository: https://github.com/artidoro/qlora/blob/main/qlora.py\n    \"\"\"\n\n    # if `target_modules` is a string, convert to lower case and check if it matches \"all-linear\"\n    if not (\n        isinstance(peft_config.target_modules, str)\n        and peft_config.target_modules.lower() == peft.tuners.tuners_utils.INCLUDE_LINEAR_LAYERS_SHORTHAND\n    ):\n        return peft_config\n\n    if not isinstance(model, transformers.PreTrainedModel):\n        raise ValueError(\n            f'Only instances of PreTrainedModel support `target_modules={peft.tuners.tuners_utils.INCLUDE_LINEAR_LAYERS_SHORTHAND!r}`'\n        )\n\n    # add HQQLinear\n    linear_classes = (torch.nn.Linear, transformers.pytorch_utils.Conv1D, hqq_quantize.HQQLinear)\n\n    linear_module_names = set()\n    for name, module in model.named_modules():\n        # match with all linear classes.\n        if isinstance(module, linear_classes):\n            names = name.rsplit('.', 1)[-1]  # get the base name\n            linear_module_names.add(names)\n\n    # ignore the last classification head for text generation models\n    output_emb = model.get_output_embeddings()\n    if output_emb is not None:\n        last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]\n        linear_module_names -= {last_module_name}\n    
peft_config.target_modules = linear_module_names\n    return peft_config\n\npeft.tuners.tuners_utils._maybe_include_all_linear_layers = _maybe_include_all_linear_layers\n\n\n@dataclass\nclass CustomHQQConfig:\n    nbits: int = 4\n    group_size: int = 64\n    view_as_float: bool = False\n    axis: int = 0\n    dynamic_config: dict[str, Any] = field(default_factory=dict)\n    skip_modules: list[str] = field(default_factory=lambda: ['lm_head'])\n    compute_dtype: str = 'float32'\n\n    def __post_init__(self):\n        self.compute_dtype = DTYPE_MAP[self.compute_dtype]\n\n    def use_aten(self):\n        return self.axis == 0 and all(d.get('axis', self.axis) == 0 for d in self.dynamic_config.values())\n\n    def get_dict(self, full_name):\n        \"\"\"Get final config dict to use for quantization, for module with full_name.\"\"\"\n        kwargs = asdict(self)\n        kwargs.pop('compute_dtype')\n        kwargs.pop('skip_modules')\n        dynamic_config = kwargs.pop('dynamic_config')\n        for key, value in dynamic_config.items():\n            if key in full_name:\n                kwargs.update(value)\n                break\n        return hqq_quantize.BaseQuantizeConfig(**kwargs)\n"
  },
  {
    "path": "utils/saver.py",
    "content": "import glob\nimport json\nimport os\nimport shutil\nimport sys\nimport time\nfrom pathlib import Path\n\nimport deepspeed\nimport torch\nimport transformers\nfrom safetensors.torch import save_file\n\nfrom utils.utils import DTYPE_MAP, is_main_process\n\n\nlast_checkpoint_time = None\n\n\ndef need_to_checkpoint(config):\n    if 'checkpoint_every_n_minutes' not in config:\n        return False\n    global last_checkpoint_time\n    checkpoint = False\n    # rank 0 tracks if we need to checkpoint, broadcasts to everyone else\n    if is_main_process():\n        current_time = time.time()\n        if last_checkpoint_time is None:\n            last_checkpoint_time = current_time\n        elif (current_time - last_checkpoint_time) / 60 > config['checkpoint_every_n_minutes']:\n            checkpoint = True\n            last_checkpoint_time = current_time\n    result = [checkpoint]\n    torch.distributed.broadcast_object_list(result, src=0)\n    return result[0]\n\n\ndef convert_state_dict_dtype(state_dict, dtype):\n    for key, v in state_dict.items():\n        state_dict[key] = v.to(device='cpu', dtype=DTYPE_MAP[dtype])\n\n\nclass Saver:\n    def __init__(self, model_engine, pipeline_model, train_dataloader, lora_config, save_root, args, config):\n        self.model_engine = model_engine\n        self.pipeline_model = pipeline_model\n        self.train_dataloader = train_dataloader\n        self.lora_config = lora_config\n        self.save_root = save_root + '/' if save_root[-1] != '/' else save_root\n        self.args = args\n        self.config = config\n        self.keep_states = config.get('keep_states', -1)\n        self.checkpoint_on_save = config.get('checkpoint_on_save', False)\n        self.chrono_states = {\n            'step': [],\n            'global_step': [],\n        }\n\n        # Load best loss from disk, if found, and if a best_loss model dir exists\n        self.best_loss = None\n        best_loss_path = os.path.join(self.save_root, 
'best_loss.txt')\n        if os.path.exists(best_loss_path) and os.path.isdir(os.path.join(self.save_root, 'best_loss')):\n            with open(best_loss_path) as f:\n                self.best_loss = float(f.read())\n            print(f'Loaded best loss from disk: {self.best_loss}')\n\n    # TODO: this is pretty hacky. Is there a way to get the state_dict from the lora model directly,\n    # but still know which layers the given pipeline parallel stage actually trained?\n    def save_lora(self, name):\n        dp_id = self.model_engine.grid.get_data_parallel_rank()\n        stage_id = self.model_engine.grid.get_pipe_parallel_rank()\n        save_dir = self.save_root + name\n        tmp_dir = os.path.join(save_dir, 'tmp')\n        if dp_id == 0 and stage_id == 0:\n            os.makedirs(tmp_dir, exist_ok=False)\n        deepspeed.comm.barrier()\n        if dp_id == 0:\n            partial_state_dict = {}\n            for name, p in self.pipeline_model.named_parameters():\n                if p.requires_grad:\n                    if not hasattr(p, 'original_name'):\n                        print(\n                            f'WARNING: parameter {name} requires_grad but does not have original_name. 
Not saving it.'\n                        )\n                        continue\n                    partial_state_dict[p.original_name.replace('.default', '').replace('.modules_to_save', '')] = (\n                        p.detach()\n                    )\n                    if 'save_dtype' in self.config:\n                        convert_state_dict_dtype(partial_state_dict, self.config['save_dtype'])\n            torch.save(partial_state_dict, os.path.join(tmp_dir, f'state_dict_{stage_id}.bin'))\n        deepspeed.comm.barrier()\n        if dp_id == 0 and stage_id == 0:\n            state_dict = {}\n            for path in glob.glob(os.path.join(tmp_dir, '*.bin')):\n                state_dict.update(torch.load(path, map_location='cpu'))\n            torch.save(state_dict, os.path.join(save_dir, 'adapter_model.bin'))\n            self.lora_config.save_pretrained(save_dir)\n            shutil.copy(self.args.config, save_dir)\n            if hasattr(self.args, 'deepspeed_config') and self.args.deepspeed_config is not None:\n                shutil.copy(self.args.deepspeed_config, save_dir)\n            self.safe_rmtree(tmp_dir)\n\n\n    def save_full_model(self, name, max_shard_size='5GB'):\n        dp_id = self.model_engine.grid.get_data_parallel_rank()\n        stage_id = self.model_engine.grid.get_pipe_parallel_rank()\n        save_dir = self.save_root + name\n        tmp_dir = os.path.join(save_dir, 'tmp')\n        if dp_id == 0 and stage_id == 0:\n            os.makedirs(tmp_dir, exist_ok=False)\n        deepspeed.comm.barrier()\n        if dp_id == 0:\n            # With BF16_Optimizer, we get pickle errors unless we do p.detach(). 
I have no idea why.\n            partial_state_dict = {p.original_name: p.detach() for p in self.pipeline_model.parameters()}\n            if 'save_dtype' in self.config:\n                convert_state_dict_dtype(partial_state_dict, self.config['save_dtype'])\n            torch.save(partial_state_dict, os.path.join(tmp_dir, f'state_dict_{stage_id}.bin'))\n        deepspeed.comm.barrier()\n        if dp_id == 0 and stage_id == 0:\n            state_dict = {}\n            for path in glob.glob(os.path.join(tmp_dir, '*.bin')):\n                state_dict.update(torch.load(path, map_location='cpu'))\n            shards, index = transformers.modeling_utils.shard_checkpoint(\n                state_dict, max_shard_size=max_shard_size, weights_name='model.safetensors'\n            )\n            for shard_file, shard in shards.items():\n                save_file(shard, os.path.join(save_dir, shard_file), metadata={'format': 'pt'})\n            if index is not None:\n                save_index_file = 'model.safetensors.index.json'\n                save_index_file = os.path.join(save_dir, save_index_file)\n                # Save the index as well\n                with open(save_index_file, 'w', encoding='utf-8') as f:\n                    content = json.dumps(index, indent=2, sort_keys=True) + '\\n'\n                    f.write(content)\n            shutil.copy(self.args.config, save_dir)\n            if hasattr(self.args, 'deepspeed_config') and self.args.deepspeed_config is not None:\n                shutil.copy(self.args.deepspeed_config, save_dir)\n            additional_files_to_copy = [\n                'added_tokens.json',\n                'config.json',\n                'generation_config.json',\n                'special_tokens_map.json',\n                'tokenizer.json',\n                'tokenizer_config.json',\n                'tokenizer.model',\n            ]\n            for path in glob.glob(os.path.join(self.config['model'], '*')):\n                if 
os.path.basename(path) in additional_files_to_copy:\n                    shutil.copy(path, save_dir)\n            self.safe_rmtree(tmp_dir)\n\n    def will_save(self, type, name):\n        if self.keep_states <= 0 or not is_main_process():\n            return\n        if type == 'step':\n            self.chrono_states['step'].append(name)\n            if len(self.chrono_states['step']) > self.keep_states:\n                print(f'Deleting {self.chrono_states[\"step\"][0]}')\n                self.safe_rmtree(os.path.join(self.save_root, self.chrono_states['step'].pop(0)))\n        elif type == 'global_step':\n            self.chrono_states['global_step'].append(name)\n            if len(self.chrono_states['global_step']) > self.keep_states:\n                print(f'Deleting {self.chrono_states[\"global_step\"][0]}')\n                self.safe_rmtree(os.path.join(self.save_root, self.chrono_states['global_step'].pop(0)))\n        else:\n            raise ValueError(f'Unknown save type: {type}')\n\n    def save_model(self, name):\n        # ignore epoch saves for chrono_states\n        if name.startswith('step'):\n            self.will_save('step', name)\n        self.save_full_model(name) if self.lora_config is None else self.save_lora(name)\n\n    def save_checkpoint(self, step):\n        self.will_save('global_step', f'global_step{step}')\n        self.model_engine.save_checkpoint(\n            self.save_root,\n            client_state={\n                'step': step,\n                'custom_loader': self.train_dataloader.state_dict(),\n            },\n            save_latest=True,\n            exclude_frozen_parameters=True,\n        )\n\n    def process_epoch(self, epoch, step):\n        save_every_n_epochs = self.config.get('save_every_n_epochs', 1)\n        save_checkpoint_on_epoch_end = self.config.get('save_checkpoint_on_epoch_end', True)\n        if self.train_dataloader.epoch != epoch:\n            if save_checkpoint_on_epoch_end:\n                
self.save_checkpoint(step)\n            if epoch % save_every_n_epochs == 0:\n                self.save_model(f'epoch{epoch}')\n            epoch = self.train_dataloader.epoch\n            if epoch > self.config['epochs']:\n                return None\n            if is_main_process():\n                print(f'Started new epoch: {epoch}')\n        return epoch\n\n    def process_step(self, step):\n        # Look at some simple \"signal files\" the user can write to save and optionally quit manually\n        should_manually_save = False\n        should_manually_quit = False\n        save_signal_file = Path(self.save_root) / 'save'\n        save_quit_signal_file = Path(self.save_root) / 'save_quit'\n        if save_signal_file.exists() and save_signal_file.is_file():\n            should_manually_save = True\n            deepspeed.comm.barrier()\n            if is_main_process():\n                os.remove(save_signal_file)\n        elif save_quit_signal_file.exists() and save_quit_signal_file.is_file():\n            should_manually_save = True\n            should_manually_quit = True\n            deepspeed.comm.barrier()\n            if is_main_process():\n                os.remove(save_quit_signal_file)\n\n        if ('save_steps' in self.config and step % self.config['save_steps'] == 0) or should_manually_save:\n            self.save_model(f'step{step}')\n            if self.checkpoint_on_save and not should_manually_save:\n                self.save_checkpoint(step)\n\n        pending_save_best_loss = os.path.exists(os.path.join(self.save_root, '.pending_save_best_loss'))\n        if pending_save_best_loss:\n            self.save_model('best_loss')\n            if is_main_process():\n                if self.old_best is not None:\n                    print(\n                        f'New best evaluation loss: {self.best_loss:.4f} from {self.old_best:.4f} (Δ{self.old_best - self.best_loss:.5f} [{100 * (1 - self.best_loss / self.old_best):.2f}%])'\n                    
)\n                else:\n                    print(f'New best evaluation loss: {self.best_loss:.4f}')\n                os.replace(\n                    os.path.join(self.save_root, '.pending_save_best_loss'),\n                    os.path.join(self.save_root, 'best_loss.txt'),\n                )\n\n        if (not self.checkpoint_on_save and need_to_checkpoint(self.config)) or should_manually_save:\n            self.save_checkpoint(step)\n\n        if should_manually_quit:\n            print('Manually quitting')\n            sys.exit()\n\n    def append_eval_results(self, loss, save_best=True):\n        if loss is not None:\n            if self.best_loss is None:\n                print(f'Evaluation loss: {loss:.4f}')\n            elif loss >= self.best_loss:\n                print(\n                    f'Evaluation loss: {loss:.4f} (best: {self.best_loss:.4f}, Δ: {self.best_loss - loss:.5f} [{100 * (1 - loss / self.best_loss):.2f}%])'\n                )\n            if self.best_loss is None or loss < self.best_loss:\n                self.old_best = self.best_loss\n                self.best_loss = loss\n                if save_best:\n                    with open(os.path.join(self.save_root, '.pending_save_best_loss'), 'w') as f:\n                        f.write(str(self.best_loss))\n        deepspeed.comm.barrier()\n\n\n    # Attempt to remove a directory tree using exponential backoff for retries (default max wait = 31s)\n    def safe_rmtree(self, dir_path, max_retries=5, initial_wait_seconds=1):\n        for attempt in range(max_retries + 1):\n            try:\n                shutil.rmtree(dir_path)\n                return\n            except OSError as e:\n                if attempt == max_retries:\n                    raise e\n                time.sleep(initial_wait_seconds * 2**attempt)\n"
  },
  {
    "path": "utils/unsloth_utils.py",
    "content": "# Unsloth Zoo - Utilities for Unsloth\n# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# This program is free software: you can redistribute it and/or modify\n# it under the terms of the GNU Lesser General Public License as published by\n# the Free Software Foundation, either version 3 of the License, or\n# (at your option) any later version.\n#\n# This program is distributed in the hope that it will be useful,\n# but WITHOUT ANY WARRANTY; without even the implied warranty of\n# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n# GNU General Public License for more details.\n#\n# You should have received a copy of the GNU Lesser General Public License\n# along with this program.  If not, see <https://www.gnu.org/licenses/>.\n\n# I (tdrussell) made a few modifications.\n\nimport torch\nfrom deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable\n\n\nclass Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):\n    \"\"\"\n    Code licensed under LGPL\n    Saves VRAM by smartly offloading to RAM.\n    Tiny hit to performance, since we mask the movement via non blocking calls.\n    \"\"\"\n\n    @staticmethod\n    @torch.amp.custom_fwd(device_type='cuda')\n    def forward(ctx, forward_function, hidden_states, *args):\n        saved_hidden_states = hidden_states.to('cpu', non_blocking=True)\n        with torch.no_grad():\n            output = forward_function(hidden_states, *args)\n        ctx.save_for_backward(saved_hidden_states)\n        ctx.forward_function = forward_function\n        ctx.args = args\n        return output\n\n    pass\n\n    @staticmethod\n    @torch.amp.custom_bwd(device_type='cuda')\n    def backward(ctx, *grads):\n        (hidden_states,) = ctx.saved_tensors\n        hidden_states = hidden_states.to('cuda', non_blocking=True).detach()\n        hidden_states.requires_grad_(True)\n        args = detach_variable(ctx.args)\n        inputs = 
(hidden_states,) + args\n        with torch.enable_grad():\n            outputs = ctx.forward_function(*inputs)\n\n        output_tensors = []\n        grad_tensors = []\n        for out, grad in zip(outputs, grads):\n            if out.requires_grad:\n                output_tensors.append(out)\n                grad_tensors.append(grad)\n        torch.autograd.backward(output_tensors, grad_tensors)\n        return (None,) + tuple(input.grad for input in inputs)\n\n    pass\n\n\npass\n\n\n@torch._disable_dynamo\ndef unsloth_checkpoint(function, *args):\n    return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)\n"
  },
  {
    "path": "utils/utils.py",
    "content": "import os.path\nimport sys\nfrom datetime import datetime\n\nimport torch\n\n\nsys.path.insert(0, os.path.abspath('axolotl/src'))\n\nfrom axolotl.utils.distributed import is_main_process, zero_first  # type: ignore # noqa\n\nDTYPE_MAP = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}\n\n\n# Simplified logger-like printer.\ndef log(msg):\n    print(f'[{datetime.now().strftime(\"%Y-%m-%d %H:%M:%S.%f\")[:-3]}] [INFO] [qlora-pipe] {msg}')\n\n\ndef eta_str(eta):\n    eta = int(eta)\n    if eta > 3600:\n        return f'{eta // 3600}h{(eta % 3600) // 60}m'\n    return f'{eta // 60}m{eta % 60}s' if eta > 60 else f'{eta}s'\n"
  }
]