Repository: uclaml/SPPO Branch: main Commit: 5e61c4e90822 Files: 29 Total size: 160.3 KB Directory structure: gitextract_ou4elal9/ ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── models_configs/ │ ├── Mistral7B-PairRM-SPPO-Iter3/ │ │ ├── configs.yaml │ │ └── prompts.txt │ └── README.md ├── run_sppo_gemma-2-27b.sh ├── run_sppo_gemma-2.sh ├── run_sppo_llama-3.sh ├── run_sppo_mistral.sh ├── scripts/ │ ├── combine_generate.py │ ├── compute_prob.py │ ├── generate.py │ ├── generate.sh │ ├── pipeline.sh │ ├── preload.py │ ├── rank.py │ └── update_dataset.py ├── setup.cfg ├── setup.py └── sppo/ ├── alignment/ │ ├── __init__.py │ ├── configs.py │ ├── data.py │ ├── model_utils.py │ └── release.py ├── run_sft.py ├── run_sppo.py └── trainer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ # Temp folders lm_cache/ logs/ checkpoints/ SPPO/ traindata/ output/ results/ wandb/ tinybench/ clean.sh synthetic_data_iter0/test.parquet mistral.parquet synthetic_data_iter0/train.parquet synthetic_ultra/test_prefs-00000-of-00001.parquet synthetic_ultra/train_prefs-00000-of-00001.parquet pairRM_eval/backlog checkpoints pairRM_eval/data pairRM_eval/figs dpo-borda-iter0 run_records.csv mistral-dpo-it-1_generated mistral.parquet synthetic_data_dpo-borda-iter1_borda/train.parquet synthetic_data_dpo-borda-iter1_chosen/train.parquet synthetic_data_dpo-borda-iter1_random/train.parquet synthetic_data_dpo-borda-iter1_wrt/train.parquet synthetic_data_iter0_borda/train.parquet synthetic_data_rpo-iter1_borda/train.parquet synthetic_data_rpo-iter1_chosen/train.parquet synthetic_data_rpo-iter1_random/train.parquet synthetic_data_rpo-iter1_wrt/train.parquet synthetic_ultra/train_prefs-00000-of-00001.parquet tinybench traindata/train.parquet synthetic_data_dpo-borda-iter1_borda/test.parquet synthetic_data_dpo-borda-iter1_borda/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_borda/test.parquet 
synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_borda/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_chosen/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_chosen/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_random/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_random/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_wrt/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-311_wrt/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_borda/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_borda/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_chosen/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_chosen/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_random/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_random/train.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_wrt/test.parquet synthetic_data_data-DPO-iter1-Borda-20k-Ours-622_wrt/train.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_borda/test.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_borda/train.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_chosen/test.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_chosen/train.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_random/test.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_random/train.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_wrt/test.parquet synthetic_data_data-DPO-iter2-Borda-20k-Ours-1276_wrt/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_borda/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_borda/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_chosen/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_chosen/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_random/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_random/train.parquet 
synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_wrt/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-311_wrt/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_borda/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_borda/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_chosen/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_chosen/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_random/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_random/train.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_wrt/test.parquet synthetic_data_data-RPO-iter1-Borda-20k-Ours-622_wrt/train.parquet synthetic_data_data-RPO-Iter1-New-309_score/test.parquet synthetic_data_data-RPO-Iter1-New-309_score/train.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_borda/test.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_borda/train.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_chosen/test.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_chosen/train.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_random/test.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_random/train.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_wrt/test.parquet synthetic_data_data-RPO-iter2-Borda-20k-Ours-319_wrt/train.parquet synthetic_data_dpo-borda-iter1_chosen/test.parquet synthetic_data_dpo-borda-iter1_chosen/train.parquet synthetic_data_dpo-borda-iter1_random/test.parquet synthetic_data_dpo-borda-iter1_random/train.parquet synthetic_data_iter0_borda/test.parquet synthetic_data_iter0_borda/train.parquet synthetic_data_mistral-3way_borda/test.parquet synthetic_data_mistral-3way_borda/train.parquet synthetic_data_mistral-3way_chosen/test.parquet synthetic_data_mistral-3way_chosen/train.parquet synthetic_data_mistral-3way_random/test.parquet synthetic_data_mistral-3way_random/train.parquet synthetic_data_mistral-3way_wrt/test.parquet 
synthetic_data_mistral-3way_wrt/train.parquet synthetic_data_rpo-iter1_borda/test.parquet synthetic_data_rpo-iter1_borda/train.parquet synthetic_data_rpo-iter1_chosen/test.parquet synthetic_data_rpo-iter1_chosen/train.parquet synthetic_data_rpo-iter1_random/test.parquet synthetic_data_rpo-iter1_random/train.parquet recipes/zephyr-7b-beta/dpo/config_full_snorkel_iter2_generated.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_data-RPO-Iter1-New-309_score.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_borda.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_chosen.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_random.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_wrt.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_mistral-3way_borda.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_borda.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_chosen.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_random.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_wrt.yaml recipes/zephyr-7b-beta/dpo/config_full_snorkel_iter2_generated.yaml recipes/zephyr-7b-beta/dpo/config_full_snorkel_iter3_generated.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_borda.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_chosen.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_random.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_dpo-borda-iter1_wrt.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_mistral-3way_score.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_mistral-7b-instruct-v0.2-iter-0_part1_borda.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_borda.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_chosen.yaml 
recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_random.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_rpo-iter1_wrt.yaml recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_RPO-Iter2-New-312_score.yaml synthetic_data_mistral-3way_score/test.parquet synthetic_data_mistral-3way_score/train.parquet synthetic_data_RPO-Iter2-New-312_score/test.parquet synthetic_data_RPO-Iter2-New-312_score/train.parquet synthetic_data_WPO-Iter1-1236_score/test.parquet synthetic_data_WPO-Iter1-1236_score/train.parquet recipes/zephyr-7b-beta/dpo/config_full_synthetic_data_WPO-Iter1-1236_score.yaml *_log_*.txt synthetic_data_* generated/ ranking/ checkpoints-*/ recipes/ push_model_to_hub.py scripts/test.sh .pre.commit-config.yaml ================================================ FILE: .pre-commit-config.yaml ================================================ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # SPPO: Self-Play Preference Optimization for Language Model Alignment ![Mistral-7B-Instruct](https://img.shields.io/badge/Model-Mistral--7B--Instruct--v0.2-green) ![Llama-3-8B-Instruct](https://img.shields.io/badge/Model-Llama--3--8B--Instruct-green) ![AlpacaEval 2.0](https://img.shields.io/badge/Task-AlpacaEval_2.0-red ) ![Open LLM](https://img.shields.io/badge/Task-Open_LLM_Leaderboard-red) ![MT-Bench](https://img.shields.io/badge/Task-MT--Bench-red) This repository contains the official code and released models for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). Authors: [Yue Wu](https://yuewu.us/)\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) [[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)] ## 🔔 News - **[01/22/2025]** SPPO has been accepted by ICLR2025! 
- **[06/29/2024]** We released [Gemma-2-9B-It-SPPO-Iter3](https://huggingface.co/UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3) trained upon [gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it), AlpacaEval 2.0 LC-win rate reached 53.27. - **[06/25/2024]** Our code is open-sourced! - **[05/01/2024]** Our paper is released on arXiv: https://arxiv.org/abs/2405.00675. ## Table of Content - [About SPPO](#about-sppo) - [Released Models](#released-models) - [Environment Setup](#environment-setup) - [Training Scripts](#training-scripts) - [Evaluation](#evaluation) - [Troubleshoot](#troubleshoot) - [Citation](#citation) - [Acknowledgements](#acknowledgements) ## About SPPO We propose a new self-play framework dubbed SPPO for language model alignment and a new learning objective (called SPPO loss) derived from the self-play framework to fine-tune large language models efficiently.


AlpacaEval 2.0 leaderboard results of normal and length-controlled (LC) win rates in percentage (%). Mistral-7B-SPPO can outperform larger models, and Mistral-7B-SPPO (best-of-16) can outperform proprietary models such as GPT-4 (6/13). Llama-3-8B-SPPO exhibits even better performance.

SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets. For more details, you can check our paper [here](https://arxiv.org/abs/2405.00675). ## Base Models and Released Models | Model | AlpacaEval2.0 LC Win Rate | AlpacaEval2.0 Win Rate | | :--- | :---: | :---: | |🤗[Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | 17.11 | 14.72 | |🤗[Mistral-7B-SPPO Iter1](https://huggingface.co/UCLA-AGI/Mistral7B-PairRM-SPPO-Iter1) |24.79 | 23.51| |🤗[Mistral-7B-SPPO Iter2](https://huggingface.co/UCLA-AGI/Mistral7B-PairRM-SPPO-Iter2) |26.89 |27.62 | |🤗[Mistral-7B-SPPO Iter3](https://huggingface.co/UCLA-AGI/Mistral7B-PairRM-SPPO-Iter3) |28.53 |31.02| |🤗[Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |22.92 |22.57 | |🤗[Llama-3-8B-SPPO Iter1](https://huggingface.co/UCLA-AGI/Llama-3-Instruct-8B-SPPO-Iter1) |31.73 |31.74 | |🤗[Llama-3-8B-SPPO Iter2](https://huggingface.co/UCLA-AGI/Llama-3-Instruct-8B-SPPO-Iter2) |35.15 |35.98 | |🤗[Llama-3-8B-SPPO Iter3](https://huggingface.co/UCLA-AGI/Llama-3-Instruct-8B-SPPO-Iter3) |38.77 |39.85 | |🤗[Gemma-2-9B-It](https://huggingface.co/google/gemma-2-9b-it) |45.08 |35.62 | |🤗[Gemma-2-9B-SPPO Iter1](https://huggingface.co/UCLA-AGI/Gemma-2-9B-It-SPPO-Iter1) |48.70 |40.76 | |🤗[Gemma-2-9B-SPPO Iter2](https://huggingface.co/UCLA-AGI/Gemma-2-9B-It-SPPO-Iter2) |50.93 | 44.64 | |🤗[Gemma-2-9B-SPPO Iter3](https://huggingface.co/UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3) |**53.27** |**47.74** | ## Environment Setup Our training code is based on the alignment-handbook 
codebase. We utilize `vllm` for generation and `pairRM` for ranking. Follow the steps below to set up your environment: 1. **Create a Virtual Environment:** ```bash conda create -n sppo python=3.10 conda activate sppo ``` 2. **Install vllm for Generation:** ```bash pip install vllm ``` 3. **Install PairRM:** ```bash git clone https://github.com/yuchenlin/LLM-Blender.git cd LLM-Blender pip install -e . ``` 4. **Download and Install Training Dependencies:** ```bash git clone https://github.com/uclaml/SPPO.git cd SPPO pip install -e . ``` ## Training Scripts Execute the training scripts based on the base model you choose: - For **Mistral-7B-Instruct-v0.2**: ```bash bash run_sppo_mistral.sh ``` - For **Llama-3-8B-Instruct**: ```bash bash run_sppo_llama-3.sh ``` These scripts manage the training iterations, generation, and PairRM ranking processes. Note that some scripts may attempt to push datasets to the Hugging Face Hub under the UCLA-AGI organization. Ensure you have write access, or modify the organization name accordingly, or comment out any `push_to_hub` commands if necessary. Detailed scripts for each component are listed as follows: ### Breakdown of Scripts: 1. **Generation:** ```bash python scripts/generate.py --model $MODEL --maxlen 2048 --output_dir $OUTPUT_DIR --prompts $PROMPTS ``` Main parameters: - `model`: Specifies the model used for generation. In the first iteration, the model should be either `mistralai/Mistral-7B-Instruct-v0.2` or `meta-llama/Meta-Llama-3-8B-Instruct`. - `maxlen`: Sets the token length for generation, defining the maximum number of tokens generated. - `pairs`: Determines the number of generated samples per prompt, with a default setting of 5. Please note that changing this number is not supported by the overall pipeline. - `output_dir`: Specifies the directory paths for saving intermediate results. - `prompts`: Defines the set of prompts used for generation. 
- `frac_len`: Enables the operation of vllm on multiple GPUs by dividing prompts into different fractions. `frac_len` defines the number of prompts in each fraction. For usage examples, see `generate.sh`. - `data_frac`: Used in conjunction with `frac_len` for multi-GPU setups, `data_frac` indicates which fraction of the data the current GPU is processing. Refer to `generate.sh` for more details. 2. **Ranking:** ```bash python scripts/rank.py --output_dir $OUTPUT_DIR --prompts $PROMPTS ``` Main Parameters: - `output_dir`: Specifies the directory paths where intermediate results are saved. Note that the default script attempts to push datasets to Hugging Face under the UCLA-AGI organization. You may need to adjust this to your organization, obtain write access for UCLA-AGI, or disable the `push_to_hub` command if necessary. - `pairs`: Sets the number of generated samples per prompt, with a default of 5. Please note that other numbers are not supported by the overall pipeline. - `frac_len`: This parameter is used to enable the use of PairRM on multiple GPUs by dividing prompts into different fractions. `frac_len` determines the number of prompts in each fraction. For usage examples, refer to `generate.sh`. - `data_frac`: Similar to `frac_len`, this option is used for running PairRM on multiple GPUs. It specifies which fraction of the data the current GPU is processing. See `generate.sh` for examples. - `prompts`: Defines the set of prompts used for generation. - `gpu`: Indicates the GPU index used for ranking; it should match the `data_frac` parameter. 3. **Training:** ```bash bash scripts/pipeline.sh --model $MODEL --iter $ITER --dataset $DATASET --output_dir $OUTPUT_DIR --num 1 ``` Main Parameters: - model: The base model for training. - dataset: The dataset used for training. - output_dir: The name of the output model. - num: The number of training epochs. 
## Evaluation We adhere to the established guidelines for evaluation and utilize the following repositories: - [AlpacaEval 2](https://github.com/tatsu-lab/alpaca_eval) - [MT-Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge) - [HuggingFace Open LLM Leaderboard](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard) We provide the model configurations used during AlpacaEval 2 in the `models_configs` directory. Please note that after the initial release of our model, we retrained it using a slightly modified prompt. The win rates observed post-retraining are comparable to the original results. ## Troubleshoot For questions related to the paper, please contact the authors via email. If you encounter any issues with the code or wish to report a bug, feel free to open an issue on our GitHub repository. ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=uclaml/SPPO&type=Date)](https://star-history.com/#uclaml/SPPO&Date) ## Citation ``` @article{wu2024self, title={Self-play preference optimization for language model alignment}, author={Wu, Yue and Sun, Zhiqing and Yuan, Huizhuo and Ji, Kaixuan and Yang, Yiming and Gu, Quanquan}, year={2024} } ``` ## Acknowledgements We thank the authors of [The Alignment Handbook](https://github.com/huggingface/alignment-handbook) for their foundational contributions to the training code. We also acknowledge the use of [PairRM](https://github.com/yuchenlin/LLM-Blender) for ranking and [vllm](https://github.com/vllm-project/vllm) for generation. 
================================================ FILE: models_configs/Mistral7B-PairRM-SPPO-Iter3/configs.yaml ================================================ Mistral7B-PairRM-SPPO-Iter3: # this should be the same as the name as the current directory prompt_template: "Mistral7B-PairRM-SPPO-Iter3/prompt.txt" fn_completions: "vllm_local_completions" completions_kwargs: model_name: "UCLA-AGI/Mistral7B-PairRM-SPPO-Iter3" model_kwargs: dtype: 'bfloat16' tokenizer_mode: "auto" trust_remote_code: True max_new_tokens: 2048 temperature: 0.7 top_p: 0.9 batch_size: 900 pretty_name: "Mistral7B-PairRM-SPPO-Iter3" # name in the leaderboard link: "https://huggingface.co/UCLA-AGI/Mistral7B-PairRM-SPPO-Iter3" ================================================ FILE: models_configs/Mistral7B-PairRM-SPPO-Iter3/prompts.txt ================================================ [INST] {instruction} [/INST] ================================================ FILE: models_configs/README.md ================================================ We have uploaded the evaluation configuration for AlpacaEval 2.0 for `Mistral7B-PairRM-SPPO-Iter3`. Please note that for models trained using the current training script, the prompts.txt file differs from the version we previously uploaded. ================================================ FILE: run_sppo_gemma-2-27b.sh ================================================ #!/bin/bash iter_num=3 for i in $(seq 1 $iter_num); do if [ "$i" -eq 1 ]; then MODEL="google/gemma-2-27b-it" else MODEL=$OUTPUT_DIR fi OUTPUT_DIR="checkpoints/Gemma-2-27B-SPPO-It-Iter${i}" PROMPT="UCLA-AGI/data-mistral-7b-instruct-sppo-iter${i}" OUT="data-gemma-2-27b-it-sppo-iter${i}" echo "runing epoch $i" DATASET="synthetic_data_gemma-2-27b-it-sppo-iter${i}_score" if [ ! 
-d "$DATASET" ]; then bash scripts/generate.sh --model $MODEL --prompt $PROMPT --out_path $OUT fi bash scripts/pipeline.sh --model $MODEL --iter $i --dataset "$DATASET" --output_dir $OUTPUT_DIR --num 1 --batch_size 1 --accumulate 8 done ================================================ FILE: run_sppo_gemma-2.sh ================================================ #!/bin/bash iter_num=3 for i in $(seq 1 $iter_num); do if [ "$i" -eq 1 ]; then MODEL="google/gemma-2-9b-it" else MODEL=$OUTPUT_DIR fi OUTPUT_DIR="checkpoints/Gemma-2-9B-SPPO-It-Iter${i}" PROMPT="UCLA-AGI/data-mistral-7b-instruct-sppo-iter${i}" OUT="data-gemma-2-9b-it-sppo-iter${i}" echo "runing epoch $i" bash scripts/generate.sh --model $MODEL --prompt $PROMPT --out_path $OUT bash scripts/pipeline.sh --model $MODEL --iter $i --dataset "synthetic_data_gemma-2-9b-it-sppo-iter${i}_score" --output_dir $OUTPUT_DIR --num 1 --batch_size 4 --accumulate 2 done ================================================ FILE: run_sppo_llama-3.sh ================================================ #!/bin/bash iter_num=3 for i in $(seq 1 $iter_num); do if [ "$i" -eq 1 ]; then MODEL="meta-llama/Meta-Llama-3-8B-Instruct" else MODEL=$OUTPUT_DIR fi OUTPUT_DIR="checkpoints/Llama-3-8B-Instruct-SPPO-Iter${i}" PROMPT="UCLA-AGI/data-mistral-7b-instruct-sppo-iter${i}" OUT="data-llama-3-8b-instruct-sppo-iter${i}" bash scripts/generate.sh --model $MODEL --prompt $PROMPT --out_path $OUT bash scripts/pipeline.sh --model $MODEL --iter $i --dataset "synthetic_data_llama-3-8b-instruct-sppo-iter${i}_score" --output_dir $OUTPUT_DIR --num 1 done ================================================ FILE: run_sppo_mistral.sh ================================================ #!/bin/bash iter_num=3 for i in $(seq 1 $iter_num); do echo "Running Iter ${i}" if [ "$i" -eq 1 ]; then MODEL="mistralai/Mistral-7B-Instruct-v0.2" else MODEL=$OUTPUT_DIR fi OUTPUT_DIR="checkpoints/Mistral-7B-Instruct-SPPO-Iter${i}" PROMPT="UCLA-AGI/data-mistral-7b-instruct-sppo-iter${i}" 
OUT="data-mistral-7b-instruct-sppo-iter${i}" bash scripts/generate.sh --model $MODEL --prompt $PROMPT --out_path $OUT bash scripts/pipeline.sh --model $MODEL --iter $i --dataset "synthetic_data_mistral-7b-instruct-sppo-iter${i}_score" --output_dir $OUTPUT_DIR --num 1 done ================================================ FILE: scripts/combine_generate.py ================================================ import json import pandas as pd import argparse def parse_arguments(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument('--output_dir', type=str, default='generated/iter1') parser.add_argument("--pairs", type=int, default=5) parser.add_argument("--numgpu", type=int, default=8) parser.add_argument("--gpu_ids", type=str, default=None) return parser.parse_args() def main(): args = parse_arguments() for j in range(args.pairs): results = [] if args.gpu_ids is not None: gpus = args.gpu_ids.strip("()").split(',') else: gpus = range(args.numgpu) for i in gpus: file_path = f"{args.output_dir}/responses_{i}_{j}.json" print(f'Reading from {file_path}') with open(file_path) as f: gen = json.load(f) results += gen output_path = f"{args.output_dir}/responses_{j}.json" print(f'Saved to {output_path}') with open(output_path, "w") as f: json.dump(results, f) if __name__ == "__main__": main() ================================================ FILE: scripts/compute_prob.py ================================================ import numpy as np from datasets import load_dataset, Dataset import json import argparse import pandas as pd import datasets import os def parse_arguments(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument("--output_dir", type=str, default="generated/iter1") parser.add_argument("--pairs", type=int, default=5) parser.add_argument("--prompts", type=str, default="UCLA-AGI/data-mistral-7b-instruct-sppo-iter1") parser.add_argument("--frac_len", type=int, default=0) 
parser.add_argument("--num_gpu", type=int, default=8) parser.add_argument("--org", type=str, default="UCLA-AGI") parser.add_argument("--gpu_ids", type=str, default=None) return parser.parse_args() def from_ranks(args): num_gpu = args.num_gpu pairs = args.pairs data = load_dataset(args.prompts, split="train") print(f"Length of dataset: {len(data)}") scores = [0 for _ in range(len(data))] if args.gpu_ids is not None: gpus = args.gpu_ids.strip("()").split(',') else: gpus = range(args.num_gpu) for data_frac, idx in enumerate(gpus): locals = np.load(f"ranking/{args.output_dir}/{idx}_{data_frac}.npy") locals = list(locals) for lidx, sc in enumerate(locals): scores[data_frac * args.frac_len + lidx] = sc probs = [] rm_scores = [] for idx, score in enumerate(scores): prb = np.zeros((pairs, pairs)) for i in range(pairs): for j in range(pairs): prb[i][j] = 1 / (1 + np.exp(score[j] - score[i])) prb = prb.tolist() probs.append(prb) rm_scores.append(score) print("Saving probabilities...") with open(f"generated/{args.output_dir}/probabilities.json", "w") as f: json.dump(probs, f) df = data.to_pandas() for i in range(pairs): with open(f"generated/{args.output_dir}/responses_{i}.json") as f: responses = json.load(f) fmt = [ [ {"content": data[j]["prompt"], "role": "user"}, {"content": responses[j], "role": "assistant"}, ] for j in range(len(data)) ] df[f"generate_{i}"] = fmt df["probability"] = probs df["rm_scores"] = rm_scores df.to_parquet(f"generated/{args.output_dir}/train.parquet") import numpy as np import os import pandas as pd import datasets def prepare_score(args): # Load dataset and convert to DataFrame train = datasets.load_dataset(f"generated/{args.output_dir}") train = pd.DataFrame(train['train']) # Calculate metrics and probabilities metrics = train['rm_scores'].apply(lambda x: np.array(x[-5:])) metrics_prob = train['probability'].apply(lambda x: np.stack(x).sum(axis=1)) maxmin = metrics.apply(lambda x: [x.argmax(), x.argmin()]) # Reorganize the DataFrame for easy 
access train_ordered = train[['generate_0', 'generate_1', 'generate_2', 'generate_3', 'generate_4', 'probability']] # Determine chosen and rejected items based on maxmin indices chosen = [train_ordered.iloc[i, maxmin[i][0]] for i in range(len(train_ordered))] rejected = [train_ordered.iloc[i, maxmin[i][1]] for i in range(len(train_ordered))] # Calculate probabilities for chosen and rejected items chosen_probs = [train_ordered['probability'].iloc[i][maxmin[i][0]][maxmin[i][1]] for i in range(len(train_ordered))] chosen_probs_win = [metrics_prob[i][maxmin[i][0]] / len(metrics_prob.iloc[0]) for i in range(len(metrics_prob))] chosen_probs_lose = [metrics_prob[i][maxmin[i][1]] / len(metrics_prob.iloc[0]) for i in range(len(metrics_prob))] # Create a new DataFrame with the results train_new = pd.DataFrame({ 'chosen': chosen, 'rejected': rejected, 'chosen_probs': chosen_probs, 'chosen_probs_win': chosen_probs_win, 'chosen_probs_lose': chosen_probs_lose }) # Determine output directory output_dir = '-'.join(args.output_dir.split('-')[1:]) OUTPATH = f'synthetic_data_{output_dir}_score' os.makedirs(OUTPATH, exist_ok=True) # Save train and test datasets to parquet files train_new.to_parquet(f'{OUTPATH}/train.parquet', index=False) print(f"Saved file to {OUTPATH}/train.parquet") # Temporary solution to make the code run, cannot use for test/evaluation purpose test = train_new.sample(n=500) test.to_parquet(f'{OUTPATH}/test.parquet', index=False) print(f"Saved file to {OUTPATH}/test.parquet") return OUTPATH def push_dataset(file_dir, org): data = Dataset.from_parquet(f"{file_dir}/train.parquet") try: test = Dataset.from_parquet(f"{file_dir}/test.parquet") except: train = pd.read_parquet(f"{file_dir}/train.parquet") # Temporary solution to make the code run, cannot use for test/evaluation purpose test = train.sample(n=500) test.to_parquet(f"{file_dir}/test.parquet", index=False) test = Dataset.from_parquet(f"{file_dir}/test.parquet") data.push_to_hub(f"{org}/{file_dir}", 
split="train", private=True) test.push_to_hub(f"{org}/{file_dir}", split="test", private=True) if __name__ == "__main__": args = parse_arguments() from_ranks(args) data = Dataset.from_parquet(f"generated/{args.output_dir}/train.parquet") data.push_to_hub(f"{args.org}/{args.output_dir}_generated", private=True) out_path = prepare_score(args) push_dataset(out_path, args.org) ================================================ FILE: scripts/generate.py ================================================ from transformers import AutoTokenizer from datasets import load_dataset from vllm import LLM, SamplingParams import argparse import torch import json import os from pathlib import Path import random import warnings import numpy as np warnings.filterwarnings("ignore") def set_seed(seed=5775709): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def parse_arguments(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( "--model", type=str, default="mistralai/Mistral-7B-Instruct-v0.2" ) parser.add_argument("--output_dir", type=str, default="generated/iter1") parser.add_argument("--prompts", type=str, default="UCLA-AGI/data-mistral-7b-instruct-sppo-iter1") parser.add_argument("--maxlen", type=int, default=2048) parser.add_argument("--pairs", type=int, default=5) parser.add_argument("--frac_len", type=int, default=0) parser.add_argument("--data_frac", type=int, default=0) parser.add_argument("--world_size", type=int, default=1) return parser.parse_args() def apply_template(text, tokenizer): return tokenizer.apply_chat_template( [{"role": "user", "content": text}, {"role": "assistant", "content": "None"}], tokenize=False, add_generate_prompt=True ).split("None")[0] def split_prompts(prompts, frac_len, data_frac): if frac_len > 0: split_len = frac_len if split_len * (data_frac + 1) > len(prompts): return prompts[split_len * data_frac:] else: return prompts[split_len * data_frac: split_len * 
(data_frac + 1)] else: return prompts[:] def main(): args = parse_arguments() model_path = args.model output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) data = load_dataset(args.prompts, split="train") if "mistral" in model_path.lower(): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") elif "llama-3" in model_path.lower(): tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") elif "gemma-2" in model_path.lower(): tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it") else: raise ValueError("Model not supported") tokenizer.pad_token = tokenizer.eos_token # import pdb # pdb.set_trace() llm = LLM( model=model_path, tensor_parallel_size=args.world_size, ) prompts = [apply_template(data[idx]["prompt"], tokenizer) for idx in range(len(data))] print(prompts[0]) data_frac, frac_len = args.data_frac, args.frac_len prompts = split_prompts(prompts, frac_len, data_frac) pairs = args.pairs os.makedirs(args.output_dir, exist_ok=True) for p in range(pairs): set_seed(p * 50) sampling_params = SamplingParams( temperature=0.7, top_p=0.9, max_tokens=args.maxlen, seed=p * 50, ) response = llm.generate(prompts, sampling_params) output = list(map(lambda x: x.outputs[0].text, response)) with open(f"{args.output_dir}/responses_{data_frac}_{p}.json", "w") as f: json.dump(output, f) if __name__ == "__main__": main() ================================================ FILE: scripts/generate.sh ================================================ set -e set -x export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" AVAILABLE_GPUS=(0 1 2 3 4 5 6 7) HF_ORG=UCLA-AGI MODEL="mistralai/Mistral-7B-Instruct-v0.2" OUTDIR="data-mistral-7b-instruct-sppo-iter1" PAIRS=5 FRAC=0 PROMPTS="UCLA-AGI/data-mistral-7b-instruct-sppo-iter1" while [[ "$#" -gt 0 ]]; do case $1 in --pairs) PAIRS="$2" shift ;; --frac) FRAC="$2" shift ;; --model) MODEL="$2" shift ;; --out_path) OUTDIR="$2" shift ;; --prompt) PROMPTS="$2" shift ;; *) echo 
"Unknown parameter passed: $1" exit 1 ;; esac shift done ##################### # Generate Data ##################### #frac length 2600 * num_gpus 8 = 20800, should be larger than the length of the dataset. Change frac_len accordingly when dataset changes FRAC_LEN=$((20800 / ${#AVAILABLE_GPUS[@]})) echo "Using frac_len ${FRAC_LEN}" ( data_frac=0 for gpu_id in ${AVAILABLE_GPUS[@]}; do CUDA_VISIBLE_DEVICES=$gpu_id python3 scripts/generate.py --model $MODEL --maxlen 2048 --output_dir "generated/$OUTDIR" --prompts $PROMPTS --pairs $PAIRS --world_size 1 --frac_len $FRAC_LEN --data_frac $data_frac > output_log_${gpu_id}.txt 2>&1 & ((data_frac+=1)); done wait ) & all_gen=$! wait $all_gen python3 scripts/combine_generate.py --output_dir "generated/$OUTDIR" --gpu_ids "$(IFS=, ; echo "${AVAILABLE_GPUS[*]}")" --pairs $PAIRS # ##################### # # Rank Data # ##################### # # frac length 2600 * num_gpus 8 = 20800, should be larger than the length of the dataset. Change frac_len accordingly when dataset changes python3 scripts/preload.py ( data_frac=0 for gpu_id in ${AVAILABLE_GPUS[@]}; do CUDA_VISIBLE_DEVICES=$gpu_id python3 scripts/rank.py --model $MODEL --output_dir $OUTDIR --pairs $PAIRS --numgpu ${#AVAILABLE_GPUS[@]} --frac_len $FRAC_LEN --data_frac $data_frac --gpu $gpu_id --prompts $PROMPTS > rank_log_${gpu_id}.txt 2>&1 & ((data_frac+=1)); done wait ) & all_rank=$! 
wait $all_rank python3 scripts/compute_prob.py --org $HF_ORG --gpu_ids "$(IFS=, ; echo "${AVAILABLE_GPUS[*]}")" --output_dir $OUTDIR --pairs $PAIRS --frac_len $FRAC_LEN --prompts $PROMPTS ================================================ FILE: scripts/pipeline.sh ================================================ set -e set -x export OMP_NUM_THREADS=2 LEARNING_RATE="5.0e-7" ITER="1" BETA="0.001" LOSS_TYPE="sppo" OPTIM="rmsprop" PREF="sppo_score" NUM=18 MODEL="mistralai/Mistral-7B-Instruct-v0.2" DATASET="synthetic_data_mistral-7b-instruct-sppo-iter1_score" BATCH_SIZE=8 ACCUMULATE=1 while [[ "$#" -gt 0 ]]; do case $1 in --learning_rate) LEARNING_RATE="$2" shift ;; --beta) BETA="$2" shift ;; --optim) OPTIM="$2" shift ;; --output_dir) OUTPUT_DIR="$2" shift ;; --iter) ITER="$2" shift ;; --loss_type) LOSS_TYPE="$2" shift ;; --prefix) PREF="$2" shift ;; --model) MODEL="$2" shift ;; --dataset) DATASET="$2" shift ;; --num) NUM="$2" shift ;; --batch_size) BATCH_SIZE="$2" shift ;; --accumulate) ACCUMULATE="$2" shift ;; *) echo "Unknown parameter passed: $1" exit 1 ;; esac shift done PREF="${PREF}_${NUM}" LEVEL1="iter${ITER}_${LEARNING_RATE}_beta${BETA}_${OPTIM}" LEVEL2="${LOSS_TYPE}_${PREF}" #OUTPUT_DIR="checkpoints/${LEVEL1}/${LEVEL2}" log_file="iter${ITER}_${LEARNING_RATE}_${BETA}_${OPTIM}_${LOSS_TYPE}_${PREF}" dataset_name=$(echo "$DATASET" | cut -d '/' -f2) new_config_file="recipes/uclaml-sppo/config_full_${dataset_name}.yaml" # Copy the original configuration file to the new one cp recipes/uclaml-sppo/config_full.yaml "$new_config_file" python3 scripts/update_dataset.py --dataset $DATASET --config "$new_config_file" >"$log_file.log" echo "logging to $log_file.log" # --main_process_port ${port} \ ACCELERATE_LOG_LEVEL=info accelerate launch \ --config_file recipes/accelerate_configs/deepspeed_zero3.yaml \ --main_process_port 2930 \ sppo/run_sppo.py "$new_config_file" \ --learning_rate=$LEARNING_RATE \ --beta=$BETA \ --optim="$OPTIM" \ --output_dir="$OUTPUT_DIR" \ 
--run_name="sppo" \ --loss_type=$LOSS_TYPE \ --per_device_train_batch_size=$BATCH_SIZE \ --gradient_accumulation_steps=$ACCUMULATE \ --model_name_or_path=$MODEL \ --num_train_epochs=$NUM # 2>&1 | tee "${log_file}.log" ================================================ FILE: scripts/preload.py ================================================ import llm_blender blender = llm_blender.Blender() blender.loadranker("llm-blender/PairRM") ================================================ FILE: scripts/rank.py ================================================ from datasets import load_dataset import json import pandas as pd import argparse import llm_blender import os import numpy as np from transformers import AutoTokenizer def parse_arguments(): """Parse command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( "--model", type=str, default="mistralai/Mistral-7B-Instruct-v0.2" ) parser.add_argument('--output_dir', type=str, default='generated/iter1') parser.add_argument("--numgpu", type=int, default=8) parser.add_argument('--prompts', type=str, default='UCLA-AGI/data-mistral-7b-instruct-sppo-iter1') parser.add_argument('--data_frac', type=int, default=0) parser.add_argument('--frac_len', type=int, default=0) parser.add_argument("--gpu", type=int, default=0) # local rank parser.add_argument("--pairs", type=int, default=5) return parser.parse_args() def ranking(args, prompts, candidates): blender = llm_blender.Blender() blender.loadranker("llm-blender/PairRM") ranks = blender.rank(prompts, candidates, return_scores=True, batch_size=1) np.save(f"ranking/{args.output_dir}/{args.gpu}_{args.data_frac}.npy", ranks) def split_prompts(prompts, frac_len, data_frac): if frac_len > 0: split_len = frac_len if split_len * (data_frac + 1) > len(prompts): return prompts[split_len * data_frac:] else: return prompts[split_len * data_frac: split_len * (data_frac + 1)] else: return prompts[:] def apply_template(text, tokenizer): return tokenizer.apply_chat_template( 
[{"role": "user", "content": text}, {"role": "assistant", "content": "None"}], tokenize=False, add_generate_prompt=True ).split("None")[0] def main(args): data = load_dataset(args.prompts, split="train") if "mistral" in args.model.lower(): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") elif "llama-3" in args.model.lower(): tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") elif "gemma-2" in args.model.lower(): tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it") else: raise ValueError("Must contain model name in the dataset name. Supported models: Mistral/Llama-3") tokenizer.pad_token = tokenizer.eos_token prompts_all = [apply_template(data[idx]["prompt"], tokenizer) for idx in range(len(data))] print(prompts_all[0]) pairs = args.pairs all_generated = [] for i in range(pairs): file_path = f"generated/{args.output_dir}/responses_{i}.json" with open(file_path) as f: gen = json.load(f) all_generated.append(gen) candidates_texts = list(zip(*all_generated)) assert len(data) == len(candidates_texts) print(f'Length of data: {len(data)}') data_frac = args.data_frac os.makedirs(f"ranking/{args.output_dir}", exist_ok=True) data_frac, frac_len = args.data_frac, args.frac_len prompts_all = split_prompts(prompts_all, frac_len, data_frac) candidates_texts = split_prompts(candidates_texts, frac_len, data_frac) ranking(args, prompts_all, candidates_texts) if __name__ == "__main__": args = parse_arguments() main(args) ================================================ FILE: scripts/update_dataset.py ================================================ import re import argparse parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str) parser.add_argument('--config', type=str, default='recipes/uclaml-sppo/config_full.yaml') args=parser.parse_args() # The path to your configuration file file_path = args.config # New dataset_mixer content you want to insert new_dataset_mixer = f'dataset_mixer:\n 
{args.dataset}: 1.0' # Read the original content of the file with open(file_path, 'r') as file: content = file.read() # Regular expression to match the dataset_mixer block and replace it # Adjust the pattern if your structure might vary significantly pattern = re.compile(r'dataset_mixer:\n\s*[^:]+:\s*\d+(\.\d+)?') # Replace the matched pattern with the new dataset_mixer content new_content = re.sub(pattern, new_dataset_mixer, content) # Write the modified content back to the file with open(file_path, 'w') as file: file.write(new_content) print("Dataset mixer updated successfully.") ================================================ FILE: setup.cfg ================================================ [isort] default_section = FIRSTPARTY ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True known_first_party = alignment known_third_party = transformers datasets fugashi git h5py matplotlib nltk numpy packaging pandas psutil pytest rouge_score sacrebleu seqeval sklearn streamlit torch tqdm line_length = 119 lines_after_imports = 2 multi_line_output = 3 use_parentheses = True [flake8] ignore = E203, E501, E741, W503, W605 max-line-length = 119 per-file-ignores = # imported but unused __init__.py: F401 [tool:pytest] doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS ================================================ FILE: setup.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py # Adapted from alignment-handbook: https://github.com/huggingface/alignment-handbook import re import shutil from pathlib import Path from setuptools import find_packages, setup # Remove stale alignment.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 stale_egg_info = Path(__file__).parent / "alignment.egg-info" if stale_egg_info.exists(): print( ( "Warning: {} exists.\n\n" "If you recently updated alignment, this is expected,\n" "but it may prevent alignment from installing in editable mode.\n\n" "This directory is automatically generated by Python's packaging tools.\n" "I will remove it now.\n\n" "See https://github.com/pypa/pip/issues/5466 for details.\n" ).format(stale_egg_info) ) shutil.rmtree(stale_egg_info) # IMPORTANT: all dependencies should be listed here with their version requirements, if any. # * If a dependency is fast-moving (e.g. transformers), pin to the exact version _deps = [ "accelerate==0.27.2", "bitsandbytes==0.41.2.post2", "black==23.1.0", "datasets==2.14.6", "deepspeed==0.12.2", "einops>=0.6.1", "evaluate==0.4.0", "flake8>=6.0.0", "hf-doc-builder>=0.4.0", "hf_transfer>=0.1.4", "huggingface-hub>=0.19.2,<1.0", "isort>=5.12.0", "ninja>=1.11.1", "numpy==1.26.4", "packaging>=23.0", "parameterized>=0.9.0", "peft==0.7.1", "protobuf<=3.20.2", # Needed to avoid conflicts with `transformers` "pytest", "safetensors>=0.3.3", "sentencepiece>=0.1.99", "scipy", "tensorboard", "torch==2.1.2", "transformers==4.42.4", "trl==0.9.6", "jinja2>=3.0.0", "tqdm>=4.64.1", ] # this is a lookup table with items like: # # tokenizers: "tokenizers==0.9.4" # packaging: "packaging" # # some of the values are versioned whereas others aren't. 
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} def deps_list(*pkgs): return [deps[pkg] for pkg in pkgs] extras = {} extras["tests"] = deps_list("pytest", "parameterized") extras["torch"] = deps_list("torch") extras["quality"] = deps_list("black", "isort", "flake8") extras["docs"] = deps_list("hf-doc-builder") extras["dev"] = extras["docs"] + extras["quality"] + extras["tests"] # core dependencies shared across the whole project - keep this to a bare minimum :) install_requires = [ deps["accelerate"], deps["bitsandbytes"], deps["einops"], deps["evaluate"], deps["datasets"], deps["deepspeed"], deps["hf_transfer"], deps["huggingface-hub"], deps["jinja2"], deps["ninja"], deps["numpy"], deps["packaging"], # utilities from PyPA to e.g., compare versions deps["peft"], deps["protobuf"], deps["safetensors"], deps["sentencepiece"], deps["scipy"], deps["tensorboard"], deps["tqdm"], # progress bars in model download and training scripts deps["transformers"], deps["trl"], ] setup( name="SPPO", version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) author="uclaml", description="Self-Play Preference Optimization (SPPO)", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", keywords="llm language-models transformers nlp deep-learning self-play", license="Apache", url="https://github.com/uclaml/SPPO", package_dir={"": "sppo"}, packages=find_packages("sppo"), zip_safe=False, extras_require=extras, python_requires=">=3.10.0", install_requires=install_requires, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Topic :: 
Scientific/Engineering :: Artificial Intelligence", ], ) ================================================ FILE: sppo/alignment/__init__.py ================================================ __version__ = "0.3.0.dev0" from .configs import DataArguments, SPPOConfig, H4ArgumentParser, ModelArguments, SFTConfig from .data import apply_chat_template, get_datasets from .model_utils import ( get_checkpoint, get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, is_adapter_model, ) ================================================ FILE: sppo/alignment/configs.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Adapted from https://github.com/huggingface/alignment-handbook import dataclasses import os import sys from dataclasses import dataclass, field from typing import Any, Dict, List, NewType, Optional, Tuple import transformers from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) DataClassType = NewType("DataClassType", Any) class H4ArgumentParser(HfArgumentParser): def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]: """ Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. 
Args: yaml_arg (`str`): The path to the config file used other_args (`List[str]`, *optional`): A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. Returns: [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line """ arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) outputs = [] # strip other args list into dict of key-value pairs other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args} used_args = {} # overwrite the default/loaded value with the value provided to the command line # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 for data_yaml, data_class in zip(arg_list, self.dataclass_types): keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} for arg, val in other_args.items(): # add only if in keys if arg in keys: base_type = data_yaml.__dataclass_fields__[arg].type inputs[arg] = val # cast type for ints, floats (default to strings) if base_type in [int, float]: inputs[arg] = base_type(val) if base_type == List[str]: inputs[arg] = [str(v) for v in val.split(",")] # bool of a non-empty string is True, so we manually check for bools if base_type == bool: if val in ["true", "True"]: inputs[arg] = True else: inputs[arg] = False # add to used-args so we can check if double add if arg not in used_args: used_args[arg] = val else: raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior") obj = data_class(**inputs) outputs.append(obj) return outputs def parse(self) -> DataClassType | Tuple[DataClassType]: if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): # If we pass only one argument to the script and it's the path to a YAML file, # let's parse it to get our arguments. 
output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) # parse command line args and yaml file elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) # parse command line args only else: output = self.parse_args_into_dataclasses() if len(output) == 1: output = output[0] return output @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune. """ base_model_revision: Optional[str] = field( default=None, metadata={"help": ("The base model checkpoint for weights initialization with PEFT adatpers.")}, ) model_name_or_path: Optional[str] = field( default=None, metadata={ "help": ( "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." ) }, ) model_revision: str = field( default="main", metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, ) model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"}) torch_dtype: Optional[str] = field( default=None, metadata={ "help": ( "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights." ), "choices": ["auto", "bfloat16", "float16", "float32"], }, ) trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) use_flash_attention_2: bool = field( default=False, metadata={ "help": ( "Whether to use flash attention 2. 
    # --- ModelArguments (continued): PEFT / LoRA and quantization options ---
    use_peft: bool = field(
        default=False,
        metadata={"help": ("Whether to use PEFT or not for training.")},
    )
    # LoRA adapter hyperparameters; only consulted when use_peft is True
    # (see get_peft_config in model_utils).
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": ("LoRA R value.")},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": ("LoRA alpha.")},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": ("LoRA dropout.")},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("LoRA target modules.")},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("Model layers to unfreeze & train")},
    )
    # bitsandbytes quantization switches, consumed by get_quantization_config.
    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})

    def __post_init__(self):
        # 8-bit and 4-bit quantization are mutually exclusive backends.
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    # Optional Jinja chat template that overrides the tokenizer's own template.
    chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
    # Mapping of dataset name -> sampling fraction; consumed by mix_datasets.
    dataset_mixer: Optional[Dict[str, float]] = field(
        default=None,
        metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")},
    )
    dataset_splits: Optional[List[str]] = field(
        default_factory=lambda: ["train", "test"],
        metadata={"help": ("List of train test splits to use in the dataset")},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    truncation_side: Optional[str] = field(
        default=None, metadata={"help": "Truncation side to use for the tokenizer."}
    )


@dataclass
class SFTConfig(transformers.TrainingArguments):
    """
    Arguments related to the training process itself. For all parameters, see:
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    max_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
    )
    logging_first_step: bool = field(
        default=True,
        metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
    )
    optim: Optional[str] = field(default="adamw_torch")


@dataclass
class SPPOConfig(transformers.TrainingArguments):
    """
    Arguments related to the SPPO training process itself. For all parameters, see:
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    # In SPPO eta = 1/beta, so a smaller beta means a larger policy update.
    beta: Optional[float] = field(
        default=0.1,
        metadata={"help": "The beta factor in DPO loss. In SPPO eta = 1/beta. Higher beta means less divergence from the initial policy."},
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": ("The Hub model branch to push the model to.")},
    )
    logging_first_step: bool = field(
        default=True,
        metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
    )
    max_prompt_length: Optional[int] = field(
        default=None,
        metadata={"help": ("For DPO/SPPO, the maximum length of the prompt to use for conditioning the model.")},
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
    )
    optim: Optional[str] = field(default="rmsprop")
    remove_unused_columns: bool = field(default=False)
    loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for SPPO.")})
tokenizer.default_chat_template # confirm the jinja template refers to a system message before inserting if "system" in chat_template: messages.insert(0, {"role": "system", "content": ""}) def apply_chat_template( example, tokenizer, skip_system_message, ): if all(k in example.keys() for k in ("chosen", "rejected")): prompt_messages = example["chosen"][:-1] # Prepend a system message if the first message is not a system message if not skip_system_message: if example["chosen"][0]["role"] != "system": prompt_messages.insert(0, {"role": "system", "content": ""}) # Now we extract the final turn to define chosen/rejected responses chosen_messages = example["chosen"][-1:] rejected_messages = example["rejected"][-1:] example["text_chosen"] = tokenizer.apply_chat_template( chosen_messages, tokenize=False, add_generate_prompt=True ) example["text_rejected"] = tokenizer.apply_chat_template( rejected_messages, tokenize=False, add_generate_prompt=True ) example["text_prompt"] = tokenizer.apply_chat_template( prompt_messages, tokenize=False, add_generate_prompt=True ) else: prompt_messages = example["chosen"][:-1] chosen_messages = example["chosen"] rejected_messages = example["rejected"] example["text_prompt"] = tokenizer.apply_chat_template( prompt_messages, tokenize=False, add_generate_prompt=True ) example["text_chosen"] = tokenizer.apply_chat_template( chosen_messages, tokenize=False, add_generate_prompt=True )[len(example["text_prompt"]) :] example["text_rejected"] = tokenizer.apply_chat_template( rejected_messages, tokenize=False, add_generate_prompt=True )[len(example["text_prompt"]) :] else: raise ValueError( f"Could not format example as dialogue for `sppo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" ) return example def get_datasets( data_config: DataArguments | dict, splits: List[str] = ["train", "test"], shuffle: bool = True, ) -> DatasetDict: """ Loads one or more datasets with varying training set proportions. 
def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
    """
    Loads and mixes datasets according to proportions specified in `dataset_mixer`.

    Args:
        dataset_mixer (`dict`):
            Dictionary containing the dataset names and their training proportions.
            By default, all test proportions are 1.
        splits (Optional[List[str]], *optional*, defaults to `None`):
            Dataset splits to load and mix. Assumes the splits exist in all datasets
            and have a `train_` or `test_` prefix. `None` falls back to
            `["train", "test"]` (previously this crashed with a TypeError).
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the training and testing/validation data.

    Returns:
        [`DatasetDict`]: concatenated (and optionally shuffled) train/test datasets.

    Raises:
        ValueError: on negative fractions, unrecognized split names, or when
            nothing could be loaded at all.
    """
    if splits is None:
        splits = ["train", "test"]
    # Fail fast, before any (potentially slow) dataset download.
    if any(frac < 0 for frac in dataset_mixer.values()):
        raise ValueError("Dataset fractions cannot be negative.")

    raw_datasets = DatasetDict()
    # Keep each train split paired with ITS dataset's fraction. The previous
    # implementation zipped a per-dataset fracs list against a per-(dataset,
    # split) list, which misaligned whenever `splits` contained more than one
    # train split.
    raw_train_datasets = []  # list of (dataset, frac)
    raw_val_datasets = []
    for ds, frac in dataset_mixer.items():
        for split in splits:
            try:
                # Try first if dataset on a Hub repo
                dataset = load_dataset(ds, split=split)
            except DatasetGenerationError:
                # If not, check local dataset
                dataset = load_from_disk(os.path.join(ds, split))

            if "train" in split:
                raw_train_datasets.append((dataset, frac))
            elif "test" in split:
                raw_val_datasets.append(dataset)
            else:
                raise ValueError(f"Split type {split} not recognized as one of test or train.")

    if raw_train_datasets:
        # Subsample each train split by its dataset's fraction.
        train_subsets = [
            dataset.select(range(int(frac * len(dataset)))) for dataset, frac in raw_train_datasets
        ]
        raw_datasets["train"] = concatenate_datasets(train_subsets)
        if shuffle:
            raw_datasets["train"] = raw_datasets["train"].shuffle(seed=42)

    # No subsampling for test datasets to enable fair comparison across models
    if raw_val_datasets:
        raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
        if shuffle:
            raw_datasets["test"] = raw_datasets["test"].shuffle(seed=42)

    if len(raw_datasets) == 0:
        # Do not reference loop variables here: they are undefined when the
        # mixer is empty (the old message read `{split}` and raised NameError).
        raise ValueError(
            f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )

    return raw_datasets
def get_current_device() -> int | str:
    """Return the accelerator-local process index on GPU, or `"cpu"` otherwise.

    For GPU we return the local process index to enable multiple GPU training.
    (Annotation fixed: the CPU path returns the string "cpu", not an int.)
    """
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`.

    Returns None on CPU-only hosts. (Docstring fixed: it previously named a
    nonexistent `get_peft_device_map`.)
    """
    return {"": get_current_device()} if torch.cuda.is_available() else None


def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None:
    """Build a bitsandbytes quantization config from `model_args`.

    Returns None when neither 4-bit nor 8-bit loading was requested. 4-bit takes
    precedence; `ModelArguments.__post_init__` already rejects having both set.
    """
    if model_args.load_in_4bit:
        # Default the compute dtype to fp16 unless an explicit torch_dtype was given.
        compute_dtype = torch.float16
        if model_args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, model_args.torch_dtype)

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        quantization_config = None

    return quantization_config


def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
    """Load and normalize the tokenizer for `model_args.model_name_or_path`.

    Ensures a pad token (falls back to EOS), applies the requested truncation
    side, caps effectively-unbounded `model_max_length` at 2048, and installs a
    chat template: an explicit one from `data_args`, else the library default
    when the tokenizer ships none.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
    )

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if data_args.truncation_side is not None:
        tokenizer.truncation_side = data_args.truncation_side

    # Set reasonable default for models without max length
    if tokenizer.model_max_length > 100_000:
        tokenizer.model_max_length = 2048

    if data_args.chat_template is not None:
        tokenizer.chat_template = data_args.chat_template
    elif tokenizer.chat_template is None and tokenizer.default_chat_template is None:
        # NOTE(review): `default_chat_template` was removed from recent
        # transformers releases — confirm the pinned version still exposes it.
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE

    return tokenizer
def is_adapter_model(model_name_or_path: str, revision: str = "main") -> bool:
    """Return True when the checkpoint at `model_name_or_path` is a PEFT adapter.

    Looks for adapter weight files on the Hub first, then in a local directory.
    """
    try:
        # Hub repo takes priority; fall back to a local directory listing.
        files = list_repo_files(model_name_or_path, revision=revision)
    except (HFValidationError, RepositoryNotFoundError):
        files = os.listdir(model_name_or_path)

    adapter_weight_names = ("adapter_model.safetensors", "adapter_model.bin")
    return any(name in files for name in adapter_weight_names)


def get_checkpoint(training_args: SFTConfig | SPPOConfig) -> Path | None:
    """Return the most recent checkpoint under `output_dir`, or None when there
    is no output directory (or no checkpoint) yet."""
    if not os.path.isdir(training_args.output_dir):
        return None
    return get_last_checkpoint(training_args.output_dir)
# Regex + replacement template for each file kind that embeds the version.
REPLACE_PATTERNS = {
    "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
    "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
}
REPLACE_FILES = {
    "init": "src/alignment/__init__.py",
    "setup": "setup.py",
}
README_FILE = "README.md"


def update_version_in_file(fname, version, pattern):
    """Rewrite the version string in `fname` using the regex registered under `pattern`."""
    with open(fname, "r", encoding="utf-8", newline="\n") as handle:
        contents = handle.read()
    regex, template = REPLACE_PATTERNS[pattern]
    # Substitute the concrete version into the replacement template, then apply it.
    updated = regex.sub(template.replace("VERSION", version), contents)
    with open(fname, "w", encoding="utf-8", newline="\n") as handle:
        handle.write(updated)


def global_version_update(version, patch=False):
    """Apply `version` to every file registered in REPLACE_FILES."""
    for kind, path in REPLACE_FILES.items():
        update_version_in_file(path, version, kind)


def get_version():
    """Read the current version out of the package `__init__` and parse it."""
    with open(REPLACE_FILES["init"], "r") as handle:
        source = handle.read()
    match = REPLACE_PATTERNS["init"][0].search(source)
    return packaging.version.parse(match.groups()[0])
def post_release_work():
    """Do all the necessary post-release steps."""
    # Derive the next dev version from the version that was just released.
    released = get_version()
    dev_version = f"{released.major}.{released.minor + 1}.0.dev0"

    # Check with the user we got that right; empty answer accepts the default.
    answer = input(f"Which version are we developing now? [{dev_version}]")
    version = answer if answer else dev_version

    print(f"Updating version to {version}.")
    global_version_update(version)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
    parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
    args = parser.parse_args()
    if not args.post_release:
        pre_release_work(patch=args.patch)
    elif args.patch:
        print("Nothing to do after a patch :-)")
    else:
        post_release_work()
""" import logging import random import sys import datasets import torch import transformers from transformers import set_seed from alignment import ( DataArguments, H4ArgumentParser, ModelArguments, SFTConfig, apply_chat_template, get_checkpoint, get_datasets, get_kbit_device_map, get_peft_config, get_quantization_config, get_tokenizer, ) from trl import SFTTrainer logger = logging.getLogger(__name__) def main(): parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) model_args, data_args, training_args = parser.parse() # Set seed for reproducibility set_seed(training_args.seed) ############### # Setup logging ############### logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", handlers=[logging.StreamHandler(sys.stdout)], ) log_level = training_args.get_process_log_level() logger.setLevel(log_level) datasets.utils.logging.set_verbosity(log_level) transformers.utils.logging.set_verbosity(log_level) transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() # Log on each process a small summary logger.warning( f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" ) logger.info(f"Model parameters {model_args}") logger.info(f"Data parameters {data_args}") logger.info(f"Training/evaluation parameters {training_args}") # Check for last checkpoint last_checkpoint = get_checkpoint(training_args) if last_checkpoint is not None and training_args.resume_from_checkpoint is None: logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") ############### # Load datasets ############### raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits) logger.info( f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in 
raw_datasets.items()]}" ) column_names = list(raw_datasets["train"].features) ################ # Load tokenizer ################ tokenizer = get_tokenizer(model_args, data_args) ##################### # Apply chat template ##################### raw_datasets = raw_datasets.map( apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "task": "sft"}, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, desc="Applying chat template", ) train_dataset = raw_datasets["train"] eval_dataset = raw_datasets["test"] with training_args.main_process_first(desc="Log a few random samples from the processed training set"): for index in random.sample(range(len(raw_datasets["train"])), 3): logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}") ####################### # Load pretrained model ####################### logger.info("*** Load pretrained model ***") torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, use_flash_attention_2=model_args.use_flash_attention_2, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) logger.info("*** Model loaded! 
***") ######################## # Initialize the Trainer ######################## trainer = SFTTrainer( model=model_args.model_name_or_path, model_init_kwargs=model_kwargs, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, dataset_text_field="text", max_seq_length=training_args.max_seq_length, tokenizer=tokenizer, packing=True, peft_config=get_peft_config(model_args), ) ############### # Training loop ############### logger.info("*** Train ***") checkpoint = None if training_args.resume_from_checkpoint is not None: checkpoint = training_args.resume_from_checkpoint elif last_checkpoint is not None: checkpoint = last_checkpoint train_result = trainer.train(resume_from_checkpoint=checkpoint) metrics = train_result.metrics metrics["train_samples"] = len(train_dataset) trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) trainer.save_state() ########## # Evaluate ########## if training_args.do_eval: logger.info("*** Evaluate ***") metrics = trainer.evaluate() metrics["eval_samples"] = len(eval_dataset) trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) ################################## # Save model and create model card ################################## logger.info("*** Save model ***") trainer.save_model(training_args.output_dir) logger.info(f"Model saved to {training_args.output_dir}") # Save everything else on main process kwargs = { "finetuned_from": model_args.model_name_or_path, "dataset": list(data_args.dataset_mixer.keys()), "dataset_tags": list(data_args.dataset_mixer.keys()), "tags": ["alignment-handbook"], } if trainer.accelerator.is_main_process: trainer.create_model_card(**kwargs) # Restore k,v cache for fast inference trainer.model.config.use_cache = True trainer.model.config.save_pretrained(training_args.output_dir) if training_args.push_to_hub is True: logger.info("Pushing to hub...") trainer.push_to_hub(**kwargs) logger.info("*** Training complete ***") if __name__ == "__main__": 
    main()


#!/usr/bin/env python
#
# Adapted from https://github.com/huggingface/alignment-handbook
import logging
import random
import sys
import yaml
import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed

from alignment import (
    DataArguments,
    SPPOConfig,
    H4ArgumentParser,
    ModelArguments,
    apply_chat_template,
    get_checkpoint,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
    is_adapter_model,
)
from peft import PeftConfig, PeftModel
from trainer import SPPOTrainer

logger = logging.getLogger(__name__)


def load_config(config_path):
    # Read a YAML config file into a plain dict.
    # NOTE(review): not referenced anywhere in the visible code of this module.
    with open(config_path, 'r') as config_file:
        return yaml.safe_load(config_file)


def setup_logging(log_level):
    # Route all log output to stdout and align transformers' verbosity with ours.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()


def load_and_process_datasets(data_args, tokenizer):
    """Load the train split, render chat templates, and rename the text columns
    to the prompt/chosen/rejected names SPPOTrainer expects."""
    raw_datasets = get_datasets(data_args, splits=["train"])
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)
    # Keep the precomputed SPPO probability columns; everything else is replaced
    # by the rendered text fields.
    column_names = [x for x in column_names if x not in ['chosen_probs', 'chosen_probs_win', 'chosen_probs_lose']]

    raw_datasets = raw_datasets.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer, "skip_system_message": True},
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Formatting comparisons with prompt template",
    )

    for split in ["train"]:
        raw_datasets[split] = raw_datasets[split].rename_columns(
            {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
        )

    return raw_datasets


def setup_model(model_args, training_args):
    """Resolve the policy model (full checkpoint or PEFT adapter) and its kwargs.

    Returns (model, ref_model, model_kwargs, ref_model_kwargs); with PEFT the
    reference model is None because the trainer can disable the adapter instead.
    """
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    model = model_args.model_name_or_path
    if is_adapter_model(model, model_args.model_revision):
        # Adapter checkpoint: load the base model, then attach the adapter;
        # the instantiated model is handed to the trainer directly.
        logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
        peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
        model_kwargs = dict(
            revision=model_args.base_model_revision,
            trust_remote_code=model_args.trust_remote_code,
            use_flash_attention_2=model_args.use_flash_attention_2,
            torch_dtype=torch_dtype,
            use_cache=False if training_args.gradient_checkpointing else True,
            device_map=get_kbit_device_map() if quantization_config is not None else None,
            quantization_config=quantization_config,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            **model_kwargs,
        )
        model = PeftModel.from_pretrained(
            base_model,
            model_args.model_name_or_path,
            revision=model_args.model_revision,
        )
        model_kwargs = None

    # Without PEFT, the policy checkpoint doubles as the reference model.
    ref_model = model
    ref_model_kwargs = model_kwargs

    if model_args.use_peft:
        ref_model = None
        ref_model_kwargs = None

    return model, ref_model, model_kwargs, ref_model_kwargs
    trainer.save_state()
    logger.info("*** Training complete ***")

    if training_args.do_eval:
        # NOTE(review): only the "train" split is loaded by
        # load_and_process_datasets, so this branch would KeyError on "test";
        # main() forces do_eval=False, which keeps it unreachable in practice.
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(raw_datasets["test"])
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def save_model_and_results(trainer, training_args, model_args, data_args):
    """Persist the trained model, model card, and config; optionally push to the Hub."""
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": list(data_args.dataset_mixer.keys()),
        "dataset_tags": list(data_args.dataset_mixer.keys()),
        "tags": ["alignment-handbook"],
    }

    # Model card / config writes happen once, on the main process only.
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Re-enable the k/v cache for fast inference after training.
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

        if training_args.push_to_hub:
            logger.info("Pushing to hub...")
            trainer.push_to_hub(**kwargs)

    trainer.accelerator.wait_for_everyone()
    logger.info("*** Training complete! ***")


def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, SPPOConfig))
    model_args, data_args, training_args = parser.parse()
    # Evaluation is disabled for SPPO runs (no test split is loaded).
    training_args.do_eval = False
    # Single SPPO iteration per invocation; the outer shell script drives
    # multi-iteration runs.
    num_iteration = 1
    try:
        for i in range(num_iteration):
            main_inner(model_args, data_args, training_args)
            print(f"-------------------------Finished Iteration {i+1}---------------------------------")
    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise


def main_inner(model_args, data_args, training_args):
    """One SPPO training pass: logging, data, model, trainer, train, save."""
    setup_logging(training_args.get_process_log_level())
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    set_seed(training_args.seed)
    # Truncate prompts from the left so the response-side tokens are kept.
    data_args.truncation_side = "left"
    tokenizer = get_tokenizer(model_args, data_args)
    raw_datasets = load_and_process_datasets(data_args, tokenizer)
    model, ref_model, model_kwargs, ref_model_kwargs = setup_model(model_args, training_args)

    trainer = SPPOTrainer(
        model,
        ref_model,
        model_init_kwargs=model_kwargs,
        ref_model_init_kwargs=ref_model_kwargs,
        args=training_args,
        beta=training_args.beta,
        train_dataset=raw_datasets["train"],
        tokenizer=tokenizer,
        max_length=training_args.max_length,
        max_prompt_length=training_args.max_prompt_length,
        peft_config=get_peft_config(model_args),
        loss_type=training_args.loss_type,
    )

    train_and_evaluate(trainer, raw_datasets, training_args)
    save_model_and_results(trainer, training_args, model_args, data_args)


if __name__ == "__main__":
    main()
defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy from functools import wraps from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from accelerate.utils import is_deepspeed_available, tqdm from datasets import Dataset from torch.utils.data import DataLoader from transformers import ( AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments, ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from trl.import_utils import is_peft_available, is_wandb_available from trl.models import PreTrainedModelWrapper, create_reference_model from trl.trainer.utils import ( DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging, ) if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): import wandb if is_deepspeed_available(): import deepspeed # def add_cols(feature, chosen_probs, chosen_probs_win, chosen_probs_lose): # feature['chosen_probs'] = chosen_probs # feature['chosen_probs_win'] = chosen_probs_win # feature['chosen_probs_lose'] = chosen_probs_lose # return feature class SPPOTrainer(Trainer): r""" Initialize SPPOTrainer. Args: model (`transformers.PreTrainedModel`): The model to train, preferably an `AutoModelForSequenceClassification`. ref_model (`PreTrainedModelWrapper`): Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. beta (`float`, defaults to 0.1): The beta factor in DPO loss. In SPPO, eta=1/beta. Higher beta means less divergence from the initial policy. 
For the IPO loss, beta is the regularization parameter denoted by tau in the paper. label_smoothing (`float`, defaults to 0): The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5. loss_type (`str`, defaults to `"sigmoid"`): The type of loss to use. 'sppo' reproduces the SPPO algorithms. Other choices are explained as follows: `"sigmoid"` represents the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf). args (`transformers.TrainingArguments`): The arguments to use for training. data_collator (`transformers.DataCollator`): The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. label_pad_token_id (`int`, defaults to `-100`): The label pad token id. This argument is required if you want to use the default data collator. padding_value (`int`, defaults to `0`): The padding value if it is different to the tokenizer's pad_token_id. truncation_mode (`str`, defaults to `keep_end`): The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. train_dataset (`datasets.Dataset`): The dataset to use for training. eval_dataset (`datasets.Dataset`): The dataset to use for evaluation. tokenizer (`transformers.PreTrainedTokenizerBase`): The tokenizer to use for training. This argument is required if you want to use the default data collator. model_init (`Callable[[], transformers.PreTrainedModel]`): The model initializer to use for training. If None is specified, the default model initializer will be used. 
callbacks (`List[transformers.TrainerCallback]`): The callbacks to use for training. optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): The optimizer and scheduler to use for training. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): The function to use to preprocess the logits before computing the metrics. max_length (`int`, defaults to `None`): The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. max_prompt_length (`int`, defaults to `None`): The maximum length of the prompt. This argument is required if you want to use the default data collator. max_target_length (`int`, defaults to `None`): The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. peft_config (`Dict`, defaults to `None`): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): If no model is provided, we need to know if the model_init returns an encoder-decoder. disable_dropout (`bool`, defaults to `True`): Whether or not to disable dropouts in `model` and `ref_model`. generate_during_eval (`bool`, defaults to `False`): Whether to sample and log generations during evaluation step. compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. precompute_ref_log_probs (`bool`, defaults to `False`): Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train without the reference model and reduce the total GPU memory needed. 
model_init_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when instantiating the model from a string ref_model_init_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when instantiating the ref model from a string model_adapter_name (`str`, defaults to `None`): Name of the train target PEFT adapter, when using LoRA with multiple adapters. ref_adapter_name (`str`, defaults to `None`): Name of the reference PEFT adapter, when using LoRA with multiple adapters. """ _tag_names = ["trl", "sppo"] def __init__( self, model: Union[PreTrainedModel, nn.Module, str] = None, ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, beta: float = 0.1, label_smoothing: float = 0, loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid", args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, label_pad_token_id: int = -100, padding_value: int = 0, truncation_mode: str = "keep_end", train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, max_length: Optional[int] = None, max_prompt_length: Optional[int] = None, max_target_length: Optional[int] = None, peft_config: Optional[Dict] = None, is_encoder_decoder: Optional[bool] = None, disable_dropout: bool = True, generate_during_eval: bool = False, compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, precompute_ref_log_probs: bool = False, model_init_kwargs: Optional[Dict] = None, ref_model_init_kwargs: Optional[Dict] = None, model_adapter_name: str = None, ref_adapter_name: str = None, ): if model_init_kwargs is None: 
model_init_kwargs = {} elif not isinstance(model, str): raise ValueError("You passed model_kwargs to the SPPOTrainer. But your model is already instantiated.") if ref_model_init_kwargs is None: ref_model_init_kwargs = {} elif not isinstance(ref_model, str): raise ValueError( "You passed ref_model_kwargs to the SPPOTrainer. But your ref_model is already instantiated." ) if isinstance(model, str): warnings.warn( "You passed a model_id to the SPPOTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." ) model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) if isinstance(ref_model, str): warnings.warn( "You passed a ref model_id to the SPPOTrainer. This will automatically create an " "`AutoModelForCausalLM`" ) ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. 
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # PEFT wrapping is intentionally not supported in this fork (the upstream
            # handling is preserved below as commented-out reference code).
            raise NotImplementedError
        #     # if model is a peft model and we have a peft_config, we merge and unload it first
        #     if isinstance(model, PeftModel):
        #         model = model.merge_and_unload()
        #     if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
        #         _support_gc_kwargs = hasattr(
        #             args, "gradient_checkpointing_kwargs"
        #         ) and "gradient_checkpointing_kwargs" in list(
        #             inspect.signature(prepare_model_for_kbit_training).parameters
        #         )
        #         preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
        #         if _support_gc_kwargs:
        #             preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
        #         model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
        #     elif getattr(args, "gradient_checkpointing", False):
        #         # For backward compatibility with older versions of transformers
        #         if hasattr(model, "enable_input_require_grads"):
        #             model.enable_input_require_grads()
        #         else:
        #             def make_inputs_require_grad(module, input, output):
        #                 output.requires_grad_(True)
        #             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
        #     # get peft model with the given config
        #     model = get_peft_model(model, peft_config)
        #     if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
        #         peft_module_casting_to_bf16(model)
        #         # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
        #         self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpoiting, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if generate_during_eval and not is_wandb_available():
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases to be installed."
                " Please install `wandb` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            # Fall back to a frozen copy of the policy as the reference model.
            self.ref_model = create_reference_model(model)

        if tokenizer is None:
            raise ValueError("tokenizer must be specified to tokenize a SPPO dataset.")
        if max_length is None:
            warnings.warn(
                "`max_length` is not set in the SPPOTrainer's init"
                " it will default to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        if max_prompt_length is None:
            warnings.warn(
                "`max_prompt_length` is not set in the SPPOTrainer's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128

        if max_target_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using an encoder decoder architecture, you should set `max_target_length` in the SPPOTrainer's init"
                " it will default to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_target_length = 128

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=tokenizer.pad_token_id,
                label_pad_token_id=label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        if disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.max_length = max_length
        self.generate_during_eval = generate_during_eval
        self.label_pad_token_id = label_pad_token_id
        self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = truncation_mode
        self.max_target_length = max_target_length
        self.tokenizer = tokenizer
        self.precompute_ref_log_probs = precompute_ref_log_probs

        # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
        # keep track of first called to avoid computation of future calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
            warnings.warn(
                "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
            )

        self.beta = beta
        self.label_smoothing = label_smoothing
        self.loss_type = loss_type

        # Per-split metric accumulator; flushed when metrics are logged.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # tokenize the dataset
        # print('=== before map', train_dataset.features)
        # chosen_probs = train_dataset['chosen_probs']
        # chosen_probs_win = train_dataset['chosen_probs_win']
        # chosen_probs_lose = train_dataset['chosen_probs_lose']
        # old_train_dataset = train_dataset
        train_dataset = train_dataset.map(self.tokenize_row)
        # print('=== before add', train_dataset.features)
        # import pandas as pd
        # mid_dataset = pd.DataFrame(train_dataset)
        # mid_dataset['chosen_probs'] = chosen_probs
        # mid_dataset['chosen_probs_win'] = chosen_probs_win
        # mid_dataset['chosen_probs_lose'] = chosen_probs_lose
        # train_dataset = Dataset.from_pandas(mid_dataset)
        # print('=== after add', train_dataset.features)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(self.tokenize_row)
        #print('=========')

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Deepspeed Zero-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        """Wrap the (frozen) reference model in a DeepSpeed engine for inference only."""
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        # deepcopy so the reference engine's config tweaks don't leak into the policy's plugin.
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
        """

        if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
            # shuffle=False so gathered log-probs line up with dataset row order.
            dataloader_params = {
                "batch_size": self.args.per_device_train_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))

            reference_chosen_logps = []
            reference_rejected_logps = []
            for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
                reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
                reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                    (reference_chosen_logp, reference_rejected_logp)
                )
                reference_chosen_logps.append(reference_chosen_logp.cpu())
                reference_rejected_logps.append(reference_rejected_logp.cpu())

            all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
            all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

            # Persist the precomputed log-probs as dataset columns so training can skip the ref model.
            self.train_dataset = self.train_dataset.add_column(
                name="reference_chosen_logps", column=all_reference_chosen_logps
            )
            self.train_dataset = self.train_dataset.add_column(
                name="reference_rejected_logps", column=all_reference_rejected_logps
            )

            # Guard so subsequent dataloader requests don't recompute.
            self._precomputed_train_ref_log_probs = True

        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not
                accepted by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
            # shuffle=False so gathered log-probs line up with dataset row order.
            dataloader_params = {
                "batch_size": self.args.per_device_eval_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

            reference_chosen_logps = []
            reference_rejected_logps = []
            for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
                reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
                reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
                    (reference_chosen_logp, reference_rejected_logp)
                )
                reference_chosen_logps.append(reference_chosen_logp.cpu())
                reference_rejected_logps.append(reference_rejected_logp.cpu())

            all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
            all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()

            eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
            eval_dataset = eval_dataset.add_column(
                name="reference_rejected_logps", column=all_reference_rejected_logps
            )

            # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

        return super().get_eval_dataloader(eval_dataset=eval_dataset)

    def build_tokenized_answer(self, prompt, answer):
        """
        Tokenize an answer relative to its prompt, compensating for tokenizer merges.

        Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        """
        full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
        prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]

        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

        # Prepare input tokens for token by token comparison
        full_input_ids = np.array(full_tokenized["input_ids"])

        if len(full_input_ids) != len(full_concat_input_ids):
            raise ValueError("Prompt input ids and answer input ids should have the same length.")

        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
        # can be merged together when tokenizing prompt+answer. This could result
        # on the last token from the prompt being different when tokenized on its own
        # vs when done as prompt+answer.
        response_token_ids_start_idx = len(prompt_input_ids)

        # If tokenized prompt is different than both prompt+answer, then it means the
        # last token has changed due to merging.
        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
            # Shift the boundary back one token so the merged token counts as response.
            response_token_ids_start_idx -= 1

        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

        if len(prompt_input_ids) != len(prompt_attention_mask):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

        return dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=prompt_attention_mask,
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
        )

    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
        """Tokenize a single row from a SPPO specific dataset.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
        we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
        the sum of the length of the prompt and the chosen/rejected response, with
        label_pad_token_id for the prompt tokens.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]

        if not self.is_encoder_decoder:
            # Check issues below for more details
            #  1. https://github.com/huggingface/trl/issues/907
            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            #  3. https://github.com/LianjiaTech/BELLE/issues/337
            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            # Last prompt token might get merged by tokenizer and
            # it should not be included for generation if that happens
            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
            # Use the shorter of the two per-answer prompt lengths for both sides.
            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Make sure prompts only have one different token at most an
            # and length only differs by 1 at most
            num_diff_tokens = sum(
                [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
            )
            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
            if num_diff_tokens > 1 or num_diff_len > 1:
                raise ValueError(
                    "Chosen and rejected prompt_input_ids might only differ on the "
                    "last token due to tokenizer merge ops."
                )

            # add BOS token to head of prompt
            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]

            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]

            # add EOS token to end of answer
            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            chosen_tokens["attention_mask"].append(1)

            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
            rejected_tokens["attention_mask"].append(1)

            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                    else:
                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
            }
            # Labels copy the full sequence but mask out the prompt span with label_pad_token_id.
            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(chosen_tokens["prompt_input_ids"])
            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
                self.label_pad_token_id
            ] * len(rejected_tokens["prompt_input_ids"])

            for k, toks in {
                "chosen_": chosen_sequence_tokens,
                "rejected_": rejected_sequence_tokens,
                "": prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            # Encoder-decoder path: prompt and targets are tokenized independently.
            chosen_tokens = self.tokenizer(
                chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            rejected_tokens = self.tokenizer(
                rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
            )
            prompt_tokens = self.tokenizer(
                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
            )

            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=batch["rejected_labels"]
                )
                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                    labels=batch["chosen_labels"]
                )
        #print('batch=======', batch.keys())
        return batch

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with self.accelerator.unwrap_model(
            self.model
        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
            if self.ref_adapter_name:
                self.model.set_adapter(self.ref_adapter_name)
            yield
            # Restore the training adapter after the caller's block finishes.
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")

    def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
        """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
        compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        # compute reference logps
        with torch.no_grad(), compte_ref_context_manager():
            if self.ref_model is None:
                # No separate ref model: use the policy with adapters disabled.
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, padded_batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, padded_batch)

        return reference_chosen_logps, reference_rejected_logps

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate the chosen and rejected inputs into a single tensor.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are
                tensors of shape (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated inputs_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
        else:
            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

        # First pass: pad every chosen_* tensor to max_length under a concatenated_* key.
        for k in batch:
            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("chosen", "concatenated")
                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
        # Second pass: append the padded rejected_* tensors along the batch dimension.
        for k in batch:
            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                concatenated_key = k.replace("rejected", "concatenated")
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(batch[k], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )

        return concatenated_batch

    def sppo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        chosen_probs: Union[torch.FloatTensor, None] = None,
        chosen_probs_win: Union[torch.FloatTensor, None] = None,
        chosen_probs_lose: Union[torch.FloatTensor, None] = None,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the SPPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model
                that assigns equal probability to all responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the SPPO
            loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards
            for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        if reference_free:
            ref_logratios = 0
        else:
            ref_logratios = reference_chosen_logps - reference_rejected_logps

        pi_logratios = pi_logratios.to(self.accelerator.device)
        ref_logratios = ref_logratios.to(self.accelerator.device)
        logits = pi_logratios - ref_logratios

        # For sppo
        logits_w = policy_chosen_logps - reference_chosen_logps
        logits_l = policy_rejected_logps - reference_rejected_logps
        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. In SPPO, beta=1/eta has a different meaning, and is usually chosen around 1e-3.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative SPPO loss.
if self.loss_type == "sigmoid": losses = ( -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing ) elif self.loss_type == "hinge": losses = torch.relu(1 - self.beta * logits) elif self.loss_type == "ipo": # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. losses = (logits - 1 / (2 * self.beta)) ** 2 elif self.loss_type == "sppo": loss_w = (logits_w - (1 / self.beta)*(chosen_probs_win - 0.5)) ** 2 loss_l = (logits_l - (1 / self.beta)*(chosen_probs_lose - 0.5)) ** 2 losses = (loss_w + loss_l)/2 elif self.loss_type == "sppo_single": loss_w = (logits_w - (1 / self.beta)*(chosen_probs - 0.5)) ** 2 loss_l = (logits_l + (1 / self.beta)*(chosen_probs - 0.5)) ** 2 losses = (loss_w + loss_l)/2 elif self.loss_type == "kto_pair": # eqn (7) of the HALOs paper chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0) rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) chosen_logratios = policy_chosen_logps - reference_chosen_logps rejected_logratios = policy_rejected_logps - reference_rejected_logps # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half. losses = torch.cat( ( 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)), 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)), ), 0, ) else: raise ValueError( f"Unknown loss type: {self.loss_type}. 
Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']" ) chosen_rewards = ( self.beta * ( policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device) ).detach() ) rejected_rewards = ( self.beta * ( policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device) ).detach() ) return losses, chosen_rewards, rejected_rewards @staticmethod def get_batch_logps( logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False, label_pad_token_id: int = -100, is_encoder_decoder: bool = False, ) -> torch.FloatTensor: """Compute the log probabilities of the given labels under the given logits. Args: logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. label_pad_token_id: The label pad token id. is_encoder_decoder: Whether the model is an encoder-decoder model. Returns: A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
""" if logits.shape[:-1] != labels.shape: raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") if not is_encoder_decoder: labels = labels[:, 1:].clone() logits = logits[:, :-1, :] loss_mask = labels != label_pad_token_id # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) else: return (per_token_logps * loss_mask).sum(-1) def concatenated_forward( self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. """ concatenated_batch = self.concatenated_inputs( batch, is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, padding_value=self.padding_value, device=self.accelerator.device, ) len_chosen = batch["chosen_labels"].shape[0] model_kwargs = ( { "labels": concatenated_batch["concatenated_labels"], "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), } if self.is_encoder_decoder else {} ) all_logits = model( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], **model_kwargs, ).logits all_logps = self.get_batch_logps( all_logits, concatenated_batch["concatenated_labels"], average_log_prob=False, is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, ) chosen_logps = all_logps[:len_chosen] rejected_logps = all_logps[len_chosen:] chosen_logits = all_logits[:len_chosen] rejected_logits = all_logits[len_chosen:] return (chosen_logps, rejected_logps, 
chosen_logits, rejected_logits) def get_batch_loss_metrics( self, model, batch: Dict[str, Union[List, torch.LongTensor]], train_eval: Literal["train", "eval"] = "train", ): """Compute the SPPO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} ( policy_chosen_logps, policy_rejected_logps, policy_chosen_logits, policy_rejected_logits, ) = self.concatenated_forward(model, batch) chosen_probs = torch.tensor(batch["chosen_probs"], dtype=float, device=policy_chosen_logps.device) chosen_probs_win = torch.tensor(batch["chosen_probs_win"], dtype=float, device=policy_chosen_logps.device) chosen_probs_lose = torch.tensor(batch["chosen_probs_lose"], dtype=float, device=policy_chosen_logps.device) # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch: reference_chosen_logps = batch["reference_chosen_logps"] reference_rejected_logps = batch["reference_rejected_logps"] else: with torch.no_grad(): if self.ref_model is None: with self.null_ref_context(): ( reference_chosen_logps, reference_rejected_logps, _, _, ) = self.concatenated_forward(self.model, batch) else: ( reference_chosen_logps, reference_rejected_logps, _, _, ) = self.concatenated_forward(self.ref_model, batch) losses, chosen_rewards, rejected_rewards = self.sppo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, chosen_probs, chosen_probs_win, chosen_probs_lose, # rejected_probs, ) reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else "" metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() 
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() return losses.mean(), metrics def compute_loss( self, model: Union[PreTrainedModel, nn.Module], inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs=False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: if not self.use_dpo_data_collator: warnings.warn( "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" ) compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext with compute_loss_context_manager(): loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") # force log the metrics self.store_metrics(metrics, train_eval="train") if return_outputs: return (loss, metrics) return loss def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: """Generate samples from the model and reference model for the given batch of inputs.""" # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with # the torch cuda amp context manager as some hidden states are silently casted to full precision. 
        generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

        with generate_context_manager():
            # Sample from the policy model.
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # if reference_output in batch use that otherwise use the reference model
            if "reference_output" in batch:
                reference_output = batch["reference_output"]
            else:
                if self.ref_model is None:
                    # Adapters disabled: base weights act as the reference.
                    with self.null_ref_context():
                        reference_output = self.model.generate(
                            input_ids=batch["prompt_input_ids"],
                            attention_mask=batch["prompt_attention_mask"],
                            max_length=self.max_length,
                            do_sample=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                        )
                else:
                    reference_output = self.ref_model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
                        max_length=self.max_length,
                        do_sample=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                    )

        # Pad both outputs to a fixed length so batch_decode sees uniform rows.
        policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)

        reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
        reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)

        return policy_output_decoded, reference_output_decoded

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        # Trainer evaluation hook: mirrors compute_loss but runs under
        # no_grad and returns (loss, logits, labels) for the eval loop.
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext

        with torch.no_grad(), prediction_context_manager():
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {
            "eval_logits/chosen": metrics["eval_logits/chosen"],
            "eval_logits/rejected": metrics["eval_logits/rejected"],
        }
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        # Dummy labels: the eval loop requires a label tensor but there is no
        # per-example ground truth here.
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        # Accumulate per-batch metrics; averaged and flushed in log().
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Overriding built-in evaluation loop to store metrics for each batch.
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
""" # Sample and save to game log if requested (for one batch to save time) if self.generate_during_eval: # Generate random indices within the range of the total number of samples num_samples = len(dataloader.dataset) random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader random_batch_dataset = dataloader.dataset.select(random_indices) random_batch = self.data_collator(random_batch_dataset) random_batch = self._prepare_inputs(random_batch) policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) self.log( { "game_log": wandb.Table( columns=["Prompt", "Policy", "Ref Model"], rows=[ [prompt, pol[len(prompt) :], ref[len(prompt) :]] for prompt, pol, ref in zip( random_batch["prompt"], policy_output_decoded, ref_output_decoded ) ], ) } ) self.state.log_history.pop() # Base evaluation initial_output = super().evaluation_loop( dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix ) return initial_output def log(self, logs: Dict[str, float]) -> None: """ Log `logs` on the various objects watching training, including stored metrics. Args: logs (`Dict[str, float]`): The values to log. """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" # Add averaged stored metrics to logs for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] return super().log(logs) @wraps(Trainer.push_to_hub) def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: """ Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
""" kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)