Repository: borisdayma/dalle-mini Branch: main Commit: f0be4de61028 Files: 48 Total size: 392.9 KB Directory structure: gitextract_0gdudwyh/ ├── .gitattributes ├── .github/ │ ├── FUNDING.yml │ └── workflows/ │ ├── check_size.yml │ ├── pypi_release.yml │ ├── style.yml │ ├── sync_to_hub.yml.backup │ └── sync_to_hub_debug.yml ├── .gitignore ├── CITATION.cff ├── Docker/ │ ├── Dockerfile │ ├── README.md │ └── build_docker.sh ├── LICENSE ├── Makefile ├── README.md ├── app/ │ ├── gradio/ │ │ ├── app.py │ │ └── backend.py │ └── streamlit/ │ ├── app.py │ └── backend.py ├── pyproject.toml ├── run_docker_image.sh ├── setup.cfg ├── setup.py ├── src/ │ └── dalle_mini/ │ ├── __init__.py │ ├── data.py │ └── model/ │ ├── __init__.py │ ├── configuration.py │ ├── modeling.py │ ├── partitions.py │ ├── processor.py │ ├── text.py │ ├── tokenizer.py │ └── utils.py └── tools/ ├── dataset/ │ └── encode_dataset.ipynb ├── inference/ │ ├── inference_pipeline.ipynb │ └── run_infer_notebook.sh └── train/ ├── config/ │ ├── mega/ │ │ └── config.json │ ├── micro/ │ │ └── config.json │ ├── mini/ │ │ └── config.json │ └── mini_glu/ │ └── config.json ├── embeddings_retrain_preparation.ipynb ├── scalable_shampoo/ │ ├── README.md │ ├── distributed_shampoo.py │ ├── quantization_utils.py │ ├── sm3.py │ └── symmetric_matrices/ │ └── symmetric_matrices.py ├── sweep.yaml └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ *.bin.* filter=lfs diff=lfs merge=lfs -text *.lfs.* filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text *.h5 filter=lfs diff=lfs merge=lfs -text *.tflite filter=lfs diff=lfs merge=lfs -text *.tar.gz filter=lfs diff=lfs merge=lfs -text *.ot filter=lfs diff=lfs merge=lfs -text *.onnx filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text *.ftz 
filter=lfs diff=lfs merge=lfs -text *.joblib filter=lfs diff=lfs merge=lfs -text *.model filter=lfs diff=lfs merge=lfs -text *.msgpack filter=lfs diff=lfs merge=lfs -text *.pb filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text ================================================ FILE: .github/FUNDING.yml ================================================ github: [borisdayma] ================================================ FILE: .github/workflows/check_size.yml ================================================ name: Check file size on: pull_request: branches: [main] # to run this workflow manually from the Actions tab workflow_dispatch: jobs: sync-to-hub: runs-on: ubuntu-latest steps: - name: Check large files uses: ActionsDesk/lfs-warning@v2.0 with: filesizelimit: 10485760 # = 10MB, so we can sync to HF spaces ================================================ FILE: .github/workflows/pypi_release.yml ================================================ # This workflow uses actions that are not certified by GitHub. # They are provided by a third-party and are governed by # separate terms of service, privacy policy, and support # documentation. 
name: Upload Python Package on: release: types: [published] jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v3 with: python-version: "3.x" - name: Install dependencies run: | python -m pip install --upgrade pip pip install build - name: Build package run: python -m build - name: Publish package uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} ================================================ FILE: .github/workflows/style.yml ================================================ name: Lint on: push: branches: [main] pull_request: branches: [main] jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: psf/black@stable - uses: actions/setup-python@v2 with: python-version: 3.9 - name: Install requirements run: pip install ".[dev]" - uses: jamescurtin/isort-action@master ================================================ FILE: .github/workflows/sync_to_hub.yml.backup ================================================ name: Sync to Hugging Face hub - Obsolete to avoid app disruptions on: push: branches: [main] # to run this workflow manually from the Actions tab workflow_dispatch: jobs: sync-to-hub: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 with: fetch-depth: 0 - name: Push to hub env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: git push https://boris:$HF_TOKEN@huggingface.co/spaces/dalle-mini/dalle-mini main ================================================ FILE: .github/workflows/sync_to_hub_debug.yml ================================================ name: Deploy to debug app on: # to run this workflow manually from the Actions tab workflow_dispatch: jobs: sync-to-hub-debug: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 with: fetch-depth: 0 - name: Push to hub env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: git push --force 
https://boris:$HF_TOKEN@huggingface.co/spaces/dalle-mini/dalle-mini-debug +HEAD:main ================================================ FILE: .gitignore ================================================ __pycache__ .ipynb_checkpoints .streamlit wandb/ *.egg-info/ jax_cache/ ================================================ FILE: CITATION.cff ================================================ # YAML 1.2 --- abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware." authors: - family-names: Dayma given-names: Boris - family-names: Patil given-names: Suraj - family-names: Cuenca given-names: Pedro - family-names: Saifullah given-names: Khalid - family-names: Abraham given-names: Tanishq - family-names: "Lê Khắc" given-names: "Phúc" - family-names: Melas given-names: Luke - family-names: Ghosh given-names: Ritobrata cff-version: "1.1.0" date-released: 2021-07-29 identifiers: keywords: - dalle - "text-to-image generation" - transformer - "zero-shot" - JAX license: "Apache-2.0" doi: 10.5281/zenodo.5146400 message: "If you use this project, please cite it using these metadata." repository-code: "https://github.com/borisdayma/dalle-mini" title: "DALL·E Mini" version: "v0.1-alpha" ... 
================================================ FILE: Docker/Dockerfile ================================================ FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04 RUN apt-get update && apt-get install -y \ git \ python3 \ python3-pip \ && rm -rf /var/lib/apt/lists/* RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \ && pip install -q \ git+https://github.com/borisdayma/dalle-mini.git \ git+https://github.com/patil-suraj/vqgan-jax.git RUN pip install jupyter WORKDIR /workspace ================================================ FILE: Docker/README.md ================================================ # Running Dalle-mini With Docker This folder contains the Dockerfile needed to build a Docker image that can easily run Dalle-mini. ## Inference Steps to run inference with Dalle-mini are as follows: 1. Build the docker image with ```dalle-mini/Docker/build_docker.sh``` 2. Run the container with ```dalle-mini/run_docker_image.sh``` 3. Navigate to ```/workspace/tools/inference/``` and run ```run_infer_notebook.sh``` 4. Click the Jupyter Notebook link and run through the notebook. ### Inference Video Tutorial Alternatively, check out a video tutorial on how to run Dalle-mini on [Linux](https://www.youtube.com/watch?v=eWpzLIa6v9E) and [Windows](https://www.youtube.com/watch?v=OqEuEe-xSKk) ================================================ FILE: Docker/build_docker.sh ================================================ docker build . -t dalle-mini:latest ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2021 The DALL·E mini Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ .PHONY: style style: black . isort . ================================================ FILE: README.md ================================================ # DALL·E Mini ## How to use it? You can use the model on [🖍️ craiyon](https://www.craiyon.com/) ## How does it work? Refer to our reports: * [DALL·E mini - Generate Images from Any Text Prompt](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini-Generate-images-from-any-text-prompt--VmlldzoyMDE4NDAy) * [DALL·E mini - Explained](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA) * [DALL·E mega - Training Journal](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2) ## Development ### Dependencies Installation For inference only, use `pip install dalle-mini`. 
For development, clone the repo and use `pip install -e ".[dev]"`. Before making a PR, check style with `make style`. You can experiment with the pipeline step by step through our [`inference pipeline notebook`](tools/inference/inference_pipeline.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb) ### Training of DALL·E mini Use [`tools/train/train.py`](tools/train/train.py). You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search. ## FAQ ### Where to find the latest models? Trained models are on 🤗 Model Hub: * [VQGAN-f16-16384](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384) for encoding/decoding images * [DALL·E mini](https://huggingface.co/dalle-mini/dalle-mini) or [DALL·E mega](https://huggingface.co/dalle-mini/dalle-mega) for generating images from a text prompt ### Where does the logo come from? The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone for us. ## Contributing Join the community on the [LAION Discord](https://discord.gg/xBPBXfcFHd). Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model with cool prompts! 
You can also use these great projects from the community: * spin off your own app with [DALL-E Playground repository](https://github.com/saharmor/dalle-playground) (thanks [Sahar](https://twitter.com/theaievangelist)) * try [DALL·E Flow](https://github.com/jina-ai/dalle-flow) project for generating, diffusion, and upscaling in a Human-in-the-Loop workflow (thanks [Han Xiao](https://github.com/hanxiao)) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jina-ai/dalle-flow/blob/main/client.ipynb) * run on [Replicate](https://replicate.com/borisdayma/dalle-mini), in the browser or via API ## Acknowledgements * 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects) * Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources * [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management ## Authors & Contributors DALL·E mini was initially developed by: * [Boris Dayma](https://github.com/borisdayma) * [Suraj Patil](https://github.com/patil-suraj) * [Pedro Cuenca](https://github.com/pcuenca) * [Khalid Saifullah](https://github.com/khalidsaifullaah) * [Tanishq Abraham](https://github.com/tmabraham) * [Phúc Lê Khắc](https://github.com/lkhphuc) * [Luke Melas](https://github.com/lukemelas) * [Ritobrata Ghosh](https://github.com/ghosh-r) Many thanks to the people who helped make it better: * the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas * [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer and always giving great suggestions * [Phil Wang](https://github.com/lucidrains) has provided a lot of cool implementations of transformer variants and gives interesting insights with 
[x-transformers](https://github.com/lucidrains/x-transformers) * [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912) * the [Gradio team](https://gradio.app/) made an amazing UI for our app ## Citing DALL·E mini If you find DALL·E mini useful in your research or wish to refer to it, please use the following BibTeX entry. ```text @misc{Dayma_DALL·E_Mini_2021, author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata}, doi = {10.5281/zenodo.5146400}, month = {7}, title = {DALL·E Mini}, url = {https://github.com/borisdayma/dalle-mini}, year = {2021} } ``` ## References Original DALL·E from "[Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092)" with image reranking from "[Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)". Image encoder from "[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841v2)".
Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461v1)" with implementation of a few variants: * "[GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)" * "[DeepNet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)" * "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)" * "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)" * "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)" * "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)" * "[Sinkformers: Transformers with Doubly Stochastic Attention](https://arxiv.org/abs/2110.11773)" * "[Foundation Transformers](https://arxiv.org/abs/2210.06423)" Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
### Citations ```text @misc{ title={Zero-Shot Text-to-Image Generation}, author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever}, year={2021}, eprint={2102.12092}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ```text @misc{ title={Learning Transferable Visual Models From Natural Language Supervision}, author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever}, year={2021}, eprint={2103.00020}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ```text @misc{ title={Taming Transformers for High-Resolution Image Synthesis}, author={Patrick Esser and Robin Rombach and Björn Ommer}, year={2021}, eprint={2012.09841}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ```text @misc{ title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension}, author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer}, year={2019}, eprint={1910.13461}, archivePrefix={arXiv}, primaryClass={cs.CL} } ``` ```text @misc{ title={Scalable Second Order Optimization for Deep Learning}, author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer}, year={2021}, eprint={2002.09018}, archivePrefix={arXiv}, primaryClass={cs.LG} } ``` ```text @misc{ title={GLU Variants Improve Transformer}, author={Noam Shazeer}, year={2020}, url={https://arxiv.org/abs/2002.05202} } ``` ```text @misc{ title={DeepNet: Scaling transformers to 1,000 layers}, author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu}, year={2022}, eprint={2203.00555}, archivePrefix={arXiv}, primaryClass={cs.LG} } ``` ```text @misc{ title={NormFormer: Improved Transformer
Pretraining with Extra Normalization}, author={Sam Shleifer and Jason Weston and Myle Ott}, year={2021}, eprint={2110.09456}, archivePrefix={arXiv}, primaryClass={cs.CL} } ``` ```text @inproceedings{ title={Swin Transformer V2: Scaling Up Capacity and Resolution}, author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo}, booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)}, year={2022} } ``` ```text @misc{ title = {CogView: Mastering Text-to-Image Generation via Transformers}, author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang}, year = {2021}, eprint = {2105.13290}, archivePrefix = {arXiv}, primaryClass = {cs.CV} } ``` ```text @misc{ title = {Root Mean Square Layer Normalization}, author = {Biao Zhang and Rico Sennrich}, year = {2019}, eprint = {1910.07467}, archivePrefix = {arXiv}, primaryClass = {cs.LG} } ``` ```text @misc{ title = {Sinkformers: Transformers with Doubly Stochastic Attention}, url = {https://arxiv.org/abs/2110.11773}, author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel}, publisher = {arXiv}, year = {2021}, } ``` ```text @misc{ title = {Smooth activations and reproducibility in deep networks}, url = {https://arxiv.org/abs/2010.09931}, author = {Shamir, Gil I. 
and Lin, Dong and Coviello, Lorenzo}, publisher = {arXiv}, year = {2020}, } ``` ```text @misc{ title = {Foundation Transformers}, url = {https://arxiv.org/abs/2210.06423}, author = {Wang, Hongyu and Ma, Shuming and Huang, Shaohan and Dong, Li and Wang, Wenhui and Peng, Zhiliang and Wu, Yu and Bajaj, Payal and Singhal, Saksham and Benhaim, Alon and Patra, Barun and Liu, Zhun and Chaudhary, Vishrav and Song, Xia and Wei, Furu}, publisher = {arXiv}, year = {2022}, } ``` ================================================ FILE: app/gradio/app.py ================================================ #!/usr/bin/env python # coding: utf-8 import os import gradio as gr from backend import get_images_from_backend block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }") backend_url = os.environ["BACKEND_SERVER"] + "/generate" def infer(prompt): response = get_images_from_backend(prompt, backend_url) return response["images"] with block: gr.Markdown("

DALL·E mini

") gr.Markdown( "DALL·E mini is an AI model that generates images from any prompt you give!" ) with gr.Group(): with gr.Box(): with gr.Row().style(mobile_collapse=False, equal_height=True): text = gr.Textbox( label="Enter your prompt", show_label=False, max_lines=1 ).style( border=(True, False, True, True), margin=False, rounded=(True, False, False, True), container=False, ) btn = gr.Button("Run").style( margin=False, rounded=(False, True, True, False), ) gallery = gr.Gallery(label="Generated images", show_label=False).style( grid=[3], height="auto" ) text.submit(infer, inputs=text, outputs=gallery) btn.click(infer, inputs=text, outputs=gallery) gr.Markdown( """___

Created by Boris Dayma et al. 2021-2022
GitHub | Project Report

"""
)

# Start the Gradio app with request queuing disabled.
block.launch(enable_queue=False)

================================================
FILE: app/gradio/backend.py
================================================
# Client requests to Dalle-Mini Backend server

import base64
from io import BytesIO

import requests
from PIL import Image


class ServiceError(Exception):
    """Raised when the backend answers with a non-200 HTTP status code."""

    def __init__(self, status_code):
        # HTTP status code returned by the backend, exposed for error messages.
        self.status_code = status_code


def get_images_from_backend(prompt, backend_url):
    """
    Send `prompt` to the generation backend and decode the returned images.

    Args:
        prompt: text prompt, posted as the JSON body `{"prompt": prompt}`.
        backend_url: full URL of the generation endpoint.

    Returns:
        dict with "images" (PIL images decoded from the base64 strings in the
        response's "images" field) and "version" (the response's "version"
        field, or "unknown" when it is absent).

    Raises:
        ServiceError: if the response status code is not 200.
    """
    # NOTE(review): no timeout is passed, so this call can block for as long as
    # the backend takes to answer.
    r = requests.post(backend_url, json={"prompt": prompt})
    if r.status_code == 200:
        json = r.json()
        images = json["images"]
        images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
        version = json.get("version", "unknown")
        return {"images": images, "version": version}
    else:
        raise ServiceError(r.status_code)


def get_model_version(url):
    """
    Return the "version" field of the JSON served at `url`.

    Raises:
        ServiceError: if the response status code is not 200.
    """
    r = requests.get(url)
    if r.status_code == 200:
        version = r.json()["version"]
        return version
    else:
        raise ServiceError(r.status_code)

================================================
FILE: app/streamlit/app.py
================================================
#!/usr/bin/env python
# coding: utf-8

import streamlit as st
from backend import ServiceError, get_images_from_backend

st.sidebar.markdown(
    """

""",
    unsafe_allow_html=True,
)
st.sidebar.markdown(
    """ ___

DALL·E mini is an AI model that generates images from any prompt you give!

Created by Boris Dayma et al. 2021-2022
GitHub | Project Report

""",
    unsafe_allow_html=True,
)

st.header("DALL·E mini")
st.subheader("Generate images from text")

prompt = st.text_input("What do you want to see?")

# When True, a missing backend configuration shows setup instructions
# instead of a generic error message.
DEBUG = False
if prompt != "":
    # Placeholder that first shows a "please wait" note and is later
    # overwritten with the prompt (or an error message).
    container = st.empty()
    container.markdown(
        f"""
Predictions may take up to 5mn under high load. Please stand by. """,
        unsafe_allow_html=True,
    )

    try:
        backend_url = st.secrets["BACKEND_SERVER"] + "/generate"
        response = get_images_from_backend(prompt, backend_url)
        selected = response["images"]
        version = response["version"]

        margin = 0.1  # for better position of zoom in arrow
        n_columns = 3
        # Image columns interleaved with narrow spacer columns; images go to
        # the even-indexed columns only.
        cols = st.columns([1] + [margin, 1] * (n_columns - 1))
        for i, img in enumerate(selected):
            cols[(i % n_columns) * 2].image(img)
        container.markdown(f"**{prompt}**")

        # st.sidebar.markdown(
        #     f"{version}", unsafe_allow_html=True
        # )

        # st.markdown(
        #     f"""
        # These results have been obtained using model `{version}` from [an ongoing training run](https://wandb.ai/dalle-mini/dalle-mini/runs/mheh9e55).
        # """
        # )

        # Clicking the button triggers a Streamlit rerun with the same prompt.
        st.button("Again!", key="again_button")

    except ServiceError as error:
        container.text(f"Service unavailable, status: {error.status_code}")
    except KeyError:
        # Intended for a missing BACKEND_SERVER secret (see the DEBUG message),
        # but any KeyError raised in the try block lands here too, e.g. a
        # backend response without an "images" field.
        if DEBUG:
            container.markdown(
                """ **Error: BACKEND_SERVER unset** Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL: ``` BACKEND_SERVER="" ``` """
            )
        else:
            container.markdown(
                "Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
            )

================================================
FILE: app/streamlit/backend.py
================================================
# Client requests to Dalle-Mini Backend server

import base64
from io import BytesIO

import requests
from PIL import Image


class ServiceError(Exception):
    """Raised when the backend answers with a non-200 HTTP status code."""

    def __init__(self, status_code):
        # HTTP status code returned by the backend, exposed for error messages.
        self.status_code = status_code


def get_images_from_backend(prompt, backend_url):
    """
    Send `prompt` to the generation backend and decode the returned images.

    Args:
        prompt: text prompt, posted as the JSON body `{"prompt": prompt}`.
        backend_url: full URL of the generation endpoint.

    Returns:
        dict with "images" (PIL images decoded from the base64 strings in the
        response's "images" field) and "version" (the response's "version"
        field, or "unknown" when it is absent).

    Raises:
        ServiceError: if the response status code is not 200.
    """
    # NOTE(review): no timeout is passed, so this call can block for as long as
    # the backend takes to answer.
    r = requests.post(backend_url, json={"prompt": prompt})
    if r.status_code == 200:
        json = r.json()
        images = json["images"]
        images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
        version = json.get("version", "unknown")
        return {"images": images, "version": version}
    else:
        raise ServiceError(r.status_code)


def get_model_version(url):
    """
    Return the "version" field of the JSON served at `url`.

    Raises:
        ServiceError: if the response status code is not 200.
    """
    r = requests.get(url)
    if r.status_code == 200:
        version = r.json()["version"]
        return version
    else:
        raise ServiceError(r.status_code)

================================================
FILE: pyproject.toml
================================================
[tool.isort]
profile = "black"

================================================
FILE: run_docker_image.sh
================================================
#!/bin/bash

# This script is used to run the docker image.
Change or remove GPU flag if you dont have nvidia-docker or the needed GPUs docker run --rm --name dallemini -it -p 8888:8888 --gpus all -v "${PWD}":/workspace dalle-mini:latest ================================================ FILE: setup.cfg ================================================ [metadata] name = dalle-mini version = attr: dalle_mini.__version__ author = Boris Dayma et al. author_email = boris.dayma@gmail.com description = DALL·E mini - Generate images from a text prompt long_description = file: README.md long_description_content_type = text/markdown url = https://github.com/borisdayma/dalle-mini project_urls = Bug Tracker = https://github.com/borisdayma/dalle-mini/issues classifiers = Programming Language :: Python :: 3 License :: OSI Approved :: Apache Software License Operating System :: OS Independent Topic :: Scientific/Engineering :: Artificial Intelligence Development Status :: 3 - Alpha Intended Audience :: Developers [options] package_dir = =src packages = find: python_requires = >=3.6 install_requires = transformers==4.25.1 einops unidecode ftfy emoji pillow jax==0.3.25 flax==0.6.3 orbax==0.0.23 wandb [options.extras_require] dev = tqdm optax braceexpand datasets[streaming] black[jupyter] isort [options.packages.find] where = src ================================================ FILE: setup.py ================================================ from setuptools import setup if __name__ == "__main__": setup() ================================================ FILE: src/dalle_mini/__init__.py ================================================ __version__ = "0.1.5" from .model import DalleBart, DalleBartProcessor ================================================ FILE: src/dalle_mini/data.py ================================================ import random from dataclasses import dataclass, field from functools import partial from pathlib import Path import jax import jax.numpy as jnp import numpy as np from braceexpand import braceexpand from datasets import 
Dataset, load_dataset

from .model.text import TextNormalizer


@dataclass
class Dataset:
    """
    Wrapper around a Hugging Face `datasets` dataset used for training and evaluation.

    `__post_init__` resolves the data files, loads the dataset with `load_dataset`
    and selects the train / validation / extra evaluation splits. `preprocess`
    filters, normalizes, optionally blanks and tokenizes captions, and
    `dataloader` yields batches of `jnp` arrays.
    """

    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    shard_by_host: bool = False
    blank_caption_prob: float = 0.0
    clip_score_column: str = "clip_score"
    min_clip_score: float = None
    max_clip_score: float = None
    filter_column: str = None
    filter_value: str = None
    multi_eval_ds: bool = False
    # The fields below are populated in __post_init__ / preprocess, not by the caller.
    # They have no default, so hasattr() tells whether a split was loaded.
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    other_eval_datasets: list = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        """Resolve data files, load the dataset and select the requested splits."""
        if self.seed_dataset is None:
            # create a random seed
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # set numpy rng
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1
        # feed blank captions only in streaming mode for now
        # otherwise dataset could be cached with same blanked captions
        if self.blank_caption_prob:
            assert (
                self.streaming is True
            ), "blank_caption_prob can only be used in streaming mode"
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            # accept braceexpand notation
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # for list of files, split training data shards by host
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # multiple validation datasets: one sub-directory of parquet files per split
        if self.multi_eval_ds:
            assert Path(
                self.dataset_repo_or_path
            ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
            data_files = {
                split.name: [str(f) for f in split.glob("*.parquet")]
                for split in Path(self.dataset_repo_or_path).glob("*")
            }
            # rename "valid" to "validation" if present for consistency
            if "valid" in data_files:
                data_files["validation"] = data_files["valid"]
                del data_files["valid"]
            self.dataset_repo_or_path = "parquet"

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )
            # other eval datasets: every split that is neither train nor validation
            other_eval_splits = dataset.keys() - {"train", "validation"}
            self.other_eval_datasets = {
                split: dataset[split] for split in other_eval_splits
            }

    def preprocess(self, tokenizer, config):
        """
        Filter, normalize, optionally blank and tokenize every loaded split in place.

        Args:
            tokenizer: text tokenizer applied to `text_column`.
            config: model config; `decoder_start_token_id`, `normalize_text`
                and `max_text_length` are read from it.
        """
        # get required config variables
        decoder_start_token_id = config.decoder_start_token_id
        normalize_text = config.normalize_text
        max_length = config.max_text_length

        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # filter data
        partial_filter_function = partial(
            filter_function,
            filter_column=self.filter_column,
            filter_value=self.filter_value,
            clip_score_column=self.clip_score_column,
            min_clip_score=self.min_clip_score,
            max_clip_score=self.max_clip_score,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).filter(partial_filter_function)
                        if self.streaming
                        else getattr(self, ds).filter(
                            partial_filter_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Filtering datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.filter(partial_filter_function)
                    if self.streaming
                    else ds.filter(
                        partial_filter_function,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Filtering datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )
            if hasattr(self, "other_eval_datasets"):
                self.other_eval_datasets = {
                    split: (
                        ds.map(partial_normalize_function)
                        if self.streaming
                        else ds.map(
                            partial_normalize_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Normalizing datasets",
                        )
                    )
                    for split, ds in self.other_eval_datasets.items()
                }

        # blank captions (training split only)
        if self.blank_caption_prob:
            partial_blank_caption_function = partial(
                blank_caption_function,
                text_column=self.text_column,
                blank_caption_prob=self.blank_caption_prob,
                rng=self.np_rng,
            )
            if hasattr(self, "train_dataset"):
                # NOTE: __post_init__ asserts that blank_caption_prob requires
                # streaming mode, so the non-streaming branch is not expected to run.
                self.train_dataset = (
                    self.train_dataset.map(partial_blank_caption_function)
                    if self.streaming
                    else self.train_dataset.map(
                        partial_blank_caption_function,
                        num_proc=None
                        if self.seed_dataset
                        else self.preprocessing_num_workers,
                        load_from_cache_file=False,
                        desc="Blanking some captions",
                    )
                )

        # preprocess (tokenize captions, build labels / decoder inputs)
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=[
                                self.text_column,
                                self.encoding_column,
                            ],
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            # `ds` is the attribute name here, so the columns
                            # must be read from the dataset object itself.
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=[
                            self.text_column,
                            self.encoding_column,
                        ],
                    )
                    if self.streaming
                    else ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=ds.column_names,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

    def dataloader(self, split, batch_size, epoch=None):
        """
        Return a generator of batches (dicts of `jnp` arrays) for `split`.

        Args:
            split: "train", "eval", or the name of one of `other_eval_datasets`.
            batch_size: number of examples per batch; an incomplete final batch
                is dropped.
            epoch: streaming training only; when set, `set_epoch` is called on
                the dataset before each pass and incremented afterwards.
        """

        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Yield batches of size `batch_size` from `dataset`, dropping the
            incomplete last batch. Example order is shuffled if `rng` is set.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(
            dataset: Dataset,
            epoch: int,
        ):
            """Accumulate streamed examples into batches of size `batch_size`."""
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True  # stop after one loop in some cases
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host, we run forever (no epoch) as hosts need to stop
                # at the same time and training data may not be split equally
                # For validation data we put the entire batch on each host and then
                # keep only the one specific to each host (could be improved but not necessary)
                if epoch is not None:
                    assert split == "train"
                    # reshuffle training data at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            ds = self.other_eval_datasets[split]

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            # Only the training split is shuffled; evaluation splits keep their
            # order (previously `input_rng` was unbound for non-train splits).
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        """
        Return `(len_train_dataset, len_eval_dataset)`.

        In streaming mode the real lengths are unknown, so `max_train_samples`
        and `max_eval_samples` are returned (None when unset). Otherwise the
        actual split lengths are returned, or None for splits that were not loaded.
        """
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int):
    """
    Shift a 2-D batch of input ids one token to the right.

    The first column is filled with `decoder_start_token_id` and the last token
    of each row is dropped. The result keeps the dtype of `input_ids`, so
    integer token ids stay integers (previously they were turned into float64).
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids


def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
    """Replace the caption with "" with probability `blank_caption_prob`."""
    if (
        blank_caption_prob
        and (rng.random() if rng is not None else np.random.random())
        < blank_caption_prob
    ):
        example[text_column] = ""
    return example


def normalize_function(example, text_column, text_normalizer):
    """Apply `text_normalizer` to the caption stored in `text_column`."""
    example[text_column] = text_normalizer(example[text_column])
    return example


def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    """
    Return True if `example` passes every configured filter.

    Each bound / filter is skipped when it is None.
    """
    if min_clip_score is not None and example[clip_score_column] < min_clip_score:
        return False
    if max_clip_score is not None and example[clip_score_column] > max_clip_score:
        return False
    if filter_column is not None and example[filter_column] != filter_value:
        return False
    return True


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    """
    Tokenize a batch of captions and build the training targets.

    Returns the tokenizer outputs plus "labels" (the encoded image tokens from
    `encoding_column`) and "decoder_input_ids" (labels shifted right, starting
    with `decoder_start_token_id`).
    """
    inputs = examples[text_column]
    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # set up targets
    # Note: labels correspond to our target indices
    # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
    labels = examples[encoding_column]
    labels = np.asarray(labels)

    # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
    model_inputs["labels"] = labels

    # In our case, this prepends the bos token and removes the last one
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs
================================================ FILE: src/dalle_mini/model/__init__.py ================================================ from .configuration import DalleBartConfig from .modeling import DalleBart from .partitions import set_partitions from .processor import DalleBartProcessor from .tokenizer import DalleBartTokenizer ================================================ FILE: src/dalle_mini/model/configuration.py ================================================ # coding=utf-8 # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" DalleBart model configuration """ import warnings from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging from .utils import PretrainedFromWandbMixin logger = logging.get_logger(__name__) class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig): model_type = "dallebart" keys_to_ignore_at_inference = ["past_key_values"] attribute_map = { "num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model", } def __init__( self, normalize_text=False, encoder_vocab_size=50264, image_vocab_size=16384, # encoded image token space image_length=256, # number of encoded tokens max_text_length=64, # max number of text tokens encoder_layers=12, encoder_ffn_dim=4096, encoder_attention_heads=16, decoder_layers=12, decoder_ffn_dim=4096, decoder_attention_heads=16, activation_function="gelu", d_model=1024, dropout=0.1, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, scale_embedding=False, gradient_checkpointing=True, use_scan=None, use_cache=True, is_encoder_decoder=True, forced_eos_token_id=None, tie_word_embeddings=False, # different modalities and sizes do_sample=True, # transformer variants use_bias=False, # use bias in attention and dense layers (except for lm_head) ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm" ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln), "subln" use_head_scale=False, # used in NormFormer use_cosine_attention=False, # used in Swin v2 tau_init=0.05, # used only in cosine attention (Swin v2) use_absolute_position_embeddings=True, # default use_swin_position_embeddings=False, # used in Swin v1/v2 use_deepnet_scaling=False, # used in Deepnet use_subln_init=False, use_glu=True, # "GLU Variants Improve Transformer" use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" sinkhorn_iters=1, 
# used in SinkFormers use_final_ln_encoder=True, # final layer normalization in encoder use_final_ln_decoder=True, # final layer normalization in decoder # parameters that should not be necessary but could affect results force_ln_scale=False, # force scale in layernorm even when followed by dense layers **kwargs, ): # text normalizer self.normalize_text = normalize_text # transformer variants self.use_bias = use_bias assert ln_type in [ "rmsnorm", "layernorm", ], "ln_type must be 'rmsnorm' or 'layernorm'" self.ln_type = ln_type if ln_positions == "deepnet": ln_positions = "postln" assert ln_positions in [ "normformer", "swinv2", "cogview", "postln", "preln", "subln", ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln', 'subln'" self.use_head_scale = use_head_scale assert use_alibi is False, "use_alibi is not supported yet" self.ln_positions = ln_positions self.use_cosine_attention = use_cosine_attention self.tau_init = tau_init self.use_absolute_position_embeddings = use_absolute_position_embeddings self.use_swin_position_embeddings = use_swin_position_embeddings self.use_deepnet_scaling = use_deepnet_scaling self.use_subln_init = use_subln_init self.use_glu = use_glu self.use_alibi = use_alibi self.sinkhorn_iters = sinkhorn_iters if ln_positions == "postln": assert ( use_final_ln_encoder ), "use_final_ln_encoder must be True when ln_positions is 'postln'" assert ( use_final_ln_decoder ), "use_final_ln_decoder must be True when ln_positions is 'postln'" self.use_final_ln_encoder = use_final_ln_encoder self.use_final_ln_decoder = use_final_ln_decoder self.force_ln_scale = force_ln_scale # common parameters self.encoder_vocab_size = encoder_vocab_size self.image_vocab_size = image_vocab_size self.image_length = image_length self.max_text_length = max_text_length self.d_model = d_model self.encoder_ffn_dim = encoder_ffn_dim self.encoder_layers = encoder_layers self.encoder_attention_heads = encoder_attention_heads self.decoder_ffn_dim = 
decoder_ffn_dim self.decoder_layers = decoder_layers self.decoder_attention_heads = decoder_attention_heads self.dropout = dropout self.attention_dropout = attention_dropout self.activation_dropout = activation_dropout self.activation_function = activation_function self.init_std = init_std self.use_cache = use_cache self.gradient_checkpointing = gradient_checkpointing # all layers are the same in most configurations self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2" assert not ( self.use_scan and ln_positions == "swinv2" ), "scan cannot be used with 'swinv2'" self.scale_embedding = ( scale_embedding # scale factor will be sqrt(d_model) if True ) # special token id's are appended to vocab if not provided decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size) bos_token_id = kwargs.pop("bos_token_id", image_vocab_size) pad_token_id = kwargs.pop("pad_token_id", image_vocab_size) eos_token_id = kwargs.pop("eos_token_id", image_vocab_size) # we generate to image_length + 1 (for bos) by default min_length = kwargs.pop("min_length", image_length + 1) max_length = kwargs.pop("max_length", image_length + 1) super().__init__( # args required in parent class is_encoder_decoder=is_encoder_decoder, tie_word_embeddings=tie_word_embeddings, forced_eos_token_id=forced_eos_token_id, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_id=eos_token_id, min_length=min_length, max_length=max_length, do_sample=do_sample, **kwargs, ) # ensure backward compatibility for BART CNN models if self.forced_bos_token_id is None and kwargs.get( "force_bos_token_to_be_generated", False ): self.forced_bos_token_id = self.bos_token_id warnings.warn( f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." "The config can simply be saved and uploaded again to be fixed." 
) ================================================ FILE: src/dalle_mini/model/modeling.py ================================================ # coding=utf-8 # Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ DalleBart model. """ import math from functools import partial from typing import Any, Dict, Optional, Tuple import flax import flax.linen as nn import jax import jax.numpy as jnp from einops import rearrange from flax.core.frozen_dict import unfreeze from flax.linen import combine_masks, make_causal_mask from flax.linen import partitioning as nn_partitioning from flax.linen.linear import PrecisionLike from flax.traverse_util import flatten_dict, unflatten_dict from jax import custom_jvp, lax from jax.random import PRNGKey from transformers.modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput, ) from transformers.modeling_flax_utils import ACT2FN from transformers.models.bart.modeling_flax_bart import ( FlaxBartAttention, FlaxBartForConditionalGeneration, FlaxBartForConditionalGenerationModule, FlaxBartModule, ) from transformers.utils import ModelOutput, logging from .configuration import DalleBartConfig from .utils import PretrainedFromWandbMixin logger = logging.get_logger(__name__) remat = nn_partitioning.remat def 
smelu(beta: Any = 1.0): """ Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations" https://arxiv.org/abs/2202.06499 """ @custom_jvp @jax.jit def _smelu(x: Any) -> Any: x = jnp.where(x <= -beta, 0.0, x) return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta)) _smelu.defjvps( lambda g, ans, x: lax.select( x == -beta, lax.full_like(g, 0), lax.select(x == beta, lax.full_like(g, 1), g), ) ) return _smelu ACT2FN.update({"smelu": smelu()}) # deepnet initialization def deepnet_init(init_std, gain=1): init = jax.nn.initializers.normal(init_std) def _init(*args, **kwargs): return gain * init(*args, **kwargs) return _init # deepnet gain deepnet_gain = { "encoder": { "alpha": lambda config: 0.81 * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625, "beta": lambda config: 0.87 * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625, }, "decoder": { "alpha": lambda config: (3 * config.decoder_layers) ** 0.25, "beta": lambda config: (12 * config.decoder_layers) ** -0.25, }, } # subln gain subln_gain = { "encoder": lambda config: math.sqrt( 1.0 / 3.0 * math.log(3 * config.decoder_layers) * math.log(2 * config.encoder_layers) ), "decoder": lambda config: math.sqrt(math.log(3 * config.decoder_layers)), } class RMSNorm(nn.Module): """ From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467 Adapted from flax.linen.LayerNorm """ epsilon: float = 1e-6 dtype: Any = jnp.float32 param_dtype: Any = jnp.float32 use_scale: bool = True scale_init: Any = jax.nn.initializers.ones @nn.compact def __call__(self, x): reduction_axes = (-1,) feature_axes = (-1,) rms_sq = self._compute_rms_sq(x, reduction_axes) return self._normalize( self, x, rms_sq, reduction_axes, feature_axes, self.dtype, self.param_dtype, self.epsilon, self.use_scale, self.scale_init, ) def _compute_rms_sq(self, x, axes): x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) rms_sq = jnp.mean(jax.lax.square(x), 
axes) return rms_sq def _normalize( self, mdl, x, rms_sq, reduction_axes, feature_axes, dtype, param_dtype, epsilon, use_scale, scale_init, ): reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes) feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) for axis in reduction_axes: stats_shape[axis] = 1 rms_sq = rms_sq.reshape(stats_shape) feature_shape = [1] * x.ndim reduced_feature_shape = [] for ax in feature_axes: feature_shape[ax] = x.shape[ax] reduced_feature_shape.append(x.shape[ax]) mul = lax.rsqrt(rms_sq + epsilon) if use_scale: scale = mdl.param( "scale", scale_init, reduced_feature_shape, param_dtype ).reshape(feature_shape) mul *= scale y = mul * x return jnp.asarray(y, dtype) def norm(type, *args, **kwargs): if type == "rmsnorm": return RMSNorm(*args, **kwargs) elif type == "layernorm": return nn.LayerNorm(*args, **kwargs) else: raise ValueError(f"Unknown norm type {type}") def dot_product_attention_weights( query: Any, key: Any, bias: Optional[Any] = None, mask: Optional[Any] = None, embed_pos: Optional[Any] = None, broadcast_dropout: bool = True, dropout_rng: Optional[PRNGKey] = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: Any = jnp.float32, precision: PrecisionLike = None, sinkhorn_iters: int = 1, is_encoder: bool = False, tau=None, ): """ Computes dot-product attention weights given query and key. mask is included into the bias. Adapted from flax.linen.attention.dot_product_attention_weights" """ assert query.ndim == key.ndim, "q, k must have same rank." assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match." assert query.shape[-2] == key.shape[-2], "q, k num_heads must match." assert query.shape[-1] == key.shape[-1], "q, k depths must match." 
# attn weight shape is (batch..., num_heads, q_length, kv_length) attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision) # divide by tau (used in Swin v2) if tau is not None: attn_weights = attn_weights / tau else: depth = query.shape[-1] attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: attn_weights = attn_weights + bias # add relative position if embed_pos is not None: attn_weights = attn_weights + embed_pos # normalize the attention weights if not is_encoder or sinkhorn_iters == 1: # sinkhorn does not work for causal (leaks info of future tokens into past) attn_weights = jax.nn.softmax(attn_weights).astype(dtype) else: # adapted from https://github.com/lucidrains/sinkhorn-transformer for i in range(sinkhorn_iters): # when causal, some attn_weights have been set to -inf through bias if i % 2 == 0: attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True) else: attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True) if mask is not None: attn_weights = jnp.where(mask, attn_weights, -jnp.inf) attn_weights = jnp.exp(attn_weights).astype(dtype) # apply attention dropout if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch + head dimensions dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:] keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) multiplier = keep.astype(attn_weights.dtype) / jnp.asarray( keep_prob, dtype=dtype ) attn_weights = attn_weights * multiplier return attn_weights class FlaxBartAttention(FlaxBartAttention): """ Edits: - causal mask is used only in decoder and considers image_length - scale attention heads per NormFormer paper """ is_encoder: bool = False is_cross_attention: bool = False q_length: int = None 
k_length: int = None def setup(self) -> None: self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f" and `num_heads`: {self.num_heads})." ) dense = partial( nn.Dense, self.embed_dim, use_bias=self.bias, dtype=self.dtype, ) if self.config.use_deepnet_scaling: gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"]( self.config ) elif self.config.use_subln_init and not self.is_cross_attention: gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config) self.q_proj = dense( kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.k_proj = dense( kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.v_proj = dense( kernel_init=deepnet_init(self.config.init_std, gain) if ( self.config.use_deepnet_scaling or (self.config.use_subln_init and not self.is_cross_attention) ) else jax.nn.initializers.normal(self.config.init_std) ) self.out_proj = dense( kernel_init=deepnet_init(self.config.init_std, gain) if ( self.config.use_deepnet_scaling or (self.config.use_subln_init and not self.is_cross_attention) ) else jax.nn.initializers.normal(self.config.init_std) ) self.dropout_layer = nn.Dropout(rate=self.dropout) if self.config.use_head_scale: self.head_scale = self.param( "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1) ) if self.config.use_cosine_attention: # TODO: try using a learnt scale, somehow it immediately diverges in my experiments self.tau = self.config.tau_init if self.config.use_swin_position_embeddings: self.rel_bias = nn.Embed( self.q_length, self.k_length * self.num_heads, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) if self.causal: # used only in decoder self.causal_mask = make_causal_mask( jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool" ) if self.config.ln_positions in ["subln"] and not 
self.is_cross_attention: self.mid_layernorm = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05 ) def __call__( self, hidden_states: jnp.ndarray, key_value_states: Optional[jnp.ndarray] = None, attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None batch_size = hidden_states.shape[0] # get query proj query_states = self.q_proj(hidden_states) # get key, value proj if is_cross_attention: # cross_attentions key_states = self.k_proj(key_value_states) value_states = self.v_proj(key_value_states) else: # self_attention key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = self._split_heads(query_states) key_states = self._split_heads(key_states) value_states = self._split_heads(value_states) # handle cache prepare causal attention mask if self.causal: query_length, key_length = query_states.shape[1], key_states.shape[1] if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] causal_mask = lax.dynamic_slice( self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length), ) else: causal_mask = self.causal_mask[:, :, :query_length, :key_length] causal_mask = jnp.broadcast_to( causal_mask, (batch_size,) + causal_mask.shape[1:] ) # combine masks if needed if attention_mask is not None and self.causal: attention_mask = jnp.broadcast_to( jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape ) attention_mask = combine_masks(attention_mask, causal_mask) elif self.causal: attention_mask = causal_mask elif attention_mask is not None: attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) # During fast autoregressive decoding, we 
feed one position at a time, # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): key_states, value_states, attention_mask = self._concatenate_to_cache( key_states, value_states, query_states, attention_mask ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype), ) else: attention_bias = None dropout_rng = None if not deterministic and self.dropout > 0.0: dropout_rng = self.make_rng("dropout") if self.config.use_cosine_attention: # normalize q and k query_states = query_states / ( jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8 ) key_states = key_states / ( jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8 ) # relative position embeddings if self.config.use_swin_position_embeddings: position_ids = jnp.arange(self.q_length) embed_pos = self.rel_bias(position_ids) embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads) else: embed_pos = None tau = self.tau if self.config.use_cosine_attention else None attn_weights = dot_product_attention_weights( query_states, key_states, bias=attention_bias, mask=attention_mask, embed_pos=embed_pos, dropout_rng=dropout_rng, dropout_rate=self.dropout, broadcast_dropout=True, deterministic=deterministic, dtype=self.dtype, precision=None, sinkhorn_iters=self.config.sinkhorn_iters, is_encoder=self.is_encoder, tau=tau, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) if self.config.use_head_scale: # per Normformer attn_output = attn_output * self.head_scale attn_output = self._merge_heads(attn_output) if self.config.ln_positions in ["subln"] and not self.is_cross_attention: attn_output = self.mid_layernorm(attn_output) attn_output = 
self.out_proj(attn_output) return attn_output, attn_weights class GLU(nn.Module): """From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202""" config: DalleBartConfig ffn_dim: int embed_dim: int dtype: jnp.dtype = jnp.float32 is_encoder: bool = False @nn.compact def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: if self.config.use_deepnet_scaling: gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"]( self.config ) elif self.config.use_subln_init: gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config) if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]: x = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(x) w = nn.Dense( self.ffn_dim, dtype=self.dtype, use_bias=self.config.use_bias, kernel_init=deepnet_init(self.config.init_std, gain) if (self.config.use_deepnet_scaling or self.config.use_subln_init) else jax.nn.initializers.normal(self.config.init_std), )(x) w = ACT2FN[self.config.activation_function](w) v = nn.Dense( self.ffn_dim, dtype=self.dtype, use_bias=self.config.use_bias, kernel_init=deepnet_init(self.config.init_std, gain) if (self.config.use_deepnet_scaling or self.config.use_subln_init) else jax.nn.initializers.normal(self.config.init_std), )(x) x = w * v if self.config.ln_positions in ["normformer", "subln"]: x = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(x) x = nn.Dropout(rate=self.config.activation_dropout)( x, deterministic=deterministic ) x = nn.Dense( self.embed_dim, dtype=self.dtype, use_bias=self.config.use_bias, kernel_init=deepnet_init(self.config.init_std, gain) if (self.config.use_deepnet_scaling or self.config.use_subln_init) else jax.nn.initializers.normal(self.config.init_std), )(x) if self.config.ln_positions in ["swinv2", "cogview"]: x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x) x = 
nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic) return x class FFN(nn.Module): """Simple FFN layer""" config: DalleBartConfig ffn_dim: int embed_dim: int dtype: jnp.dtype = jnp.float32 is_encoder: bool = False @nn.compact def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: if self.config.use_deepnet_scaling: gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"]( self.config ) elif self.config.use_subln_init: gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config) if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]: x = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(x) x = nn.Dense( self.ffn_dim, dtype=self.dtype, use_bias=self.config.use_bias, kernel_init=deepnet_init(self.config.init_std, gain) if (self.config.use_deepnet_scaling or self.config.use_subln_init) else jax.nn.initializers.normal(self.config.init_std), )(x) x = ACT2FN[self.config.activation_function](x) if self.config.ln_positions in ["normformer", "subln"]: x = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(x) x = nn.Dropout(rate=self.config.activation_dropout)( x, deterministic=deterministic ) x = nn.Dense( self.embed_dim, dtype=self.dtype, use_bias=self.config.use_bias, kernel_init=deepnet_init(self.config.init_std, gain) if (self.config.use_deepnet_scaling or self.config.use_subln_init) else jax.nn.initializers.normal(self.config.init_std), )(x) if self.config.ln_positions in ["swinv2", "cogview"]: x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x) x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic) return x class FlaxBartEncoderLayer(nn.Module): """ Edits: - no bias - use custom FlaxBartAttention """ config: DalleBartConfig dtype: jnp.dtype = jnp.float32 add_norm: bool = False use_scale: bool = True @nn.compact def __call__( self, hidden_states: 
jnp.ndarray, attention_mask: jnp.ndarray, output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: if self.config.use_scan: hidden_states = hidden_states[0] res_gain = ( deepnet_gain["encoder"]["alpha"](self.config) if self.config.use_deepnet_scaling else 1 ) embed_dim = self.config.d_model residual = hidden_states if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]: hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(hidden_states) hidden_states, attn_weights = FlaxBartAttention( config=self.config, embed_dim=embed_dim, num_heads=self.config.encoder_attention_heads, dropout=self.config.attention_dropout, bias=self.config.use_bias, dtype=self.dtype, is_encoder=True, is_cross_attention=False, q_length=self.config.max_text_length, k_length=self.config.max_text_length, )(hidden_states=hidden_states, attention_mask=attention_mask) if self.config.ln_positions in ["normformer", "swinv2", "cogview"]: hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)( hidden_states ) hidden_states = nn.Dropout(rate=self.config.dropout)( hidden_states, deterministic=deterministic ) hidden_states = residual * res_gain + hidden_states if self.config.ln_positions in ["postln"]: hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)( hidden_states ) residual = hidden_states ff_block = ( GLU( config=self.config, ffn_dim=self.config.encoder_ffn_dim, embed_dim=embed_dim, dtype=self.dtype, is_encoder=True, ) if self.config.use_glu else FFN( config=self.config, ffn_dim=self.config.encoder_ffn_dim, embed_dim=embed_dim, dtype=self.dtype, is_encoder=True, ) ) hidden_states = ff_block(hidden_states, deterministic=deterministic) hidden_states = residual * res_gain + hidden_states if self.add_norm: use_scale = self.use_scale or self.config.force_ln_scale hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, 
use_scale=use_scale, )(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) if self.config.use_scan: outputs = (outputs, None) return outputs class FlaxBartDecoderLayer(nn.Module): """ Edits: - no bias - use custom FlaxBartAttention """ config: DalleBartConfig dtype: jnp.dtype = jnp.float32 add_norm: bool = False use_scale: bool = True @nn.compact def __call__( self, hidden_states: jnp.ndarray, attention_mask: jnp.ndarray, encoder_hidden_states: Optional[jnp.ndarray] = None, encoder_attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: if self.config.use_scan: hidden_states = hidden_states[0] res_gain = ( deepnet_gain["decoder"]["alpha"](self.config) if self.config.use_deepnet_scaling else 1 ) embed_dim = self.config.d_model residual = hidden_states # Self Attention if self.config.ln_positions in ["normformer", "cogview", "preln"]: hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(hidden_states) hidden_states, attn_weights = FlaxBartAttention( config=self.config, embed_dim=embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, causal=True, bias=self.config.use_bias, dtype=self.dtype, is_encoder=False, is_cross_attention=False, q_length=self.config.image_length, k_length=self.config.image_length, )( hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache, ) if self.config.ln_positions in ["normformer", "swinv2", "cogview"]: hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)( hidden_states ) hidden_states = nn.Dropout(rate=self.config.dropout)( hidden_states, deterministic=deterministic ) hidden_states = residual * res_gain + hidden_states if self.config.ln_positions in ["postln"]: hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)( hidden_states ) # 
Cross Attention cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states if self.config.ln_positions in ["normformer", "cogview", "preln"]: hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, )(hidden_states) hidden_states, cross_attn_weights = FlaxBartAttention( config=self.config, embed_dim=embed_dim, num_heads=self.config.decoder_attention_heads, dropout=self.config.attention_dropout, bias=self.config.use_bias, dtype=self.dtype, is_encoder=False, is_cross_attention=True, q_length=self.config.image_length, k_length=self.config.max_text_length, )( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, ) if self.config.ln_positions in ["normformer", "swinv2", "cogview"]: hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05 )(hidden_states) hidden_states = nn.Dropout(rate=self.config.dropout)( hidden_states, deterministic=deterministic ) hidden_states = residual * res_gain + hidden_states if self.config.ln_positions in ["postln"]: hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05 )(hidden_states) # Feed forward residual = hidden_states ff_block = ( GLU( config=self.config, ffn_dim=self.config.decoder_ffn_dim, embed_dim=embed_dim, dtype=self.dtype, is_encoder=False, ) if self.config.use_glu else FFN( config=self.config, ffn_dim=self.config.decoder_ffn_dim, embed_dim=embed_dim, dtype=self.dtype, is_encoder=False, ) ) hidden_states = ff_block(hidden_states, deterministic=deterministic) hidden_states = residual * res_gain + hidden_states if self.add_norm: use_scale = self.use_scale or self.config.force_ln_scale hidden_states = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=use_scale, )(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights, cross_attn_weights) if self.config.use_scan: outputs = (outputs, None) return 
outputs class FlaxBartEncoderLayerCollection(nn.Module): config: DalleBartConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation """ Edits: - use custom FlaxBartEncoderLayer - allow Gradient Checkpointing (nn.remat) """ @nn.compact def __call__( self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None n_layers = self.config.encoder_layers layer = ( remat( FlaxBartEncoderLayer, static_argnums=(2, 3), prevent_cse=not self.config.use_scan, ) if self.config.gradient_checkpointing else FlaxBartEncoderLayer ) if self.config.use_scan: # all blocks are the same so we use nn.scan assert not output_attentions, "cannot scan with output_attentions" assert not output_hidden_states, "cannot scan with output_hidden_states" hidden_states = (hidden_states,) # we use a scale on all norms (even last layer) to allow scanning hidden_states, _ = nn.scan( layer, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, "dropout": True}, in_axes=(nn.broadcast, nn.broadcast, nn.broadcast), length=n_layers, )( self.config, dtype=self.dtype, add_norm=self.config.ln_positions == "postln", name="FlaxBartEncoderLayers", )( hidden_states, attention_mask, output_attentions, deterministic, ) hidden_states = hidden_states[0] else: for i in range(n_layers): if output_hidden_states: all_hidden_states += (hidden_states,) # final layernorm on the output of the last layer # or every 6 layers for Swin v2 add_norm = self.config.ln_positions == "postln" or ( self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0) and (i != n_layers - 1) ) # we don't need to scale the norm for the last layer use_scale = i != n_layers - 1 layer_outputs = layer( self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale, name=f"FlaxBartEncoderLayer_{i}", )( hidden_states, 
attention_mask, output_attentions, deterministic, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) # add hidden states from the last layer if output_hidden_states: all_hidden_states += (hidden_states,) outputs = [ hidden_states, all_hidden_states, all_self_attns, ] if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns, ) class FlaxBartDecoderLayerCollection(nn.Module): config: DalleBartConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation """ Edits: - use custom FlaxBartDecoderLayer - allow Gradient Checkpointing (nn.remat) """ @nn.compact def __call__( self, hidden_states, attention_mask, encoder_hidden_states: Optional[jnp.ndarray] = None, encoder_attention_mask: Optional[jnp.ndarray] = None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = ( () if (output_attentions and encoder_hidden_states is not None) else None ) n_layers = self.config.decoder_layers layer = ( remat( FlaxBartDecoderLayer, static_argnums=(4, 5, 6), prevent_cse=not self.config.use_scan, ) if self.config.gradient_checkpointing else FlaxBartDecoderLayer ) if self.config.use_scan: # all blocks are the same so we use nn.scan assert not output_attentions, "cannot scan with output_attentions" assert not output_hidden_states, "cannot scan with output_hidden_states" hidden_states = (hidden_states,) # we use a scale on all norms (even last layer) to allow scanning hidden_states, _ = nn.scan( layer, variable_axes={"params": 0, "cache": 0}, split_rngs={"params": True, "dropout": True}, in_axes=( nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, 
nn.broadcast, ), length=n_layers, )( self.config, dtype=self.dtype, add_norm=self.config.ln_positions == "postln", name="FlaxBartDecoderLayers", )( hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, init_cache, output_attentions, deterministic, ) hidden_states = hidden_states[0] else: for i in range(n_layers): if output_hidden_states: all_hidden_states += (hidden_states,) # final layernorm on the output of the last layer # or every 6 layers for Swin v2 add_norm = self.config.ln_positions == "postln" or ( self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0) and (i != n_layers - 1) ) # we don't need to scale the norm for the last layer use_scale = i != n_layers - 1 layer_outputs = layer( self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale, name=f"FlaxBartDecoderLayer_{i}", )( hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, init_cache, output_attentions, deterministic, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) outputs = [ hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, ] if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) class FlaxBartEncoder(nn.Module): config: DalleBartConfig embed_tokens: nn.Embed dtype: jnp.dtype = jnp.float32 # the dtype of the computation """ Edits: - offset set to 0 (no padding token) - use max_text_length instead of max_position_embeddings - use custom FlaxBartEncoderLayerCollection - embed_tokens cannot be None (issue at compile time) """ def setup(self): self.dropout_layer = 
nn.Dropout(rate=self.config.dropout) embed_dim = self.config.d_model self.padding_idx = self.config.pad_token_id self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 0 if self.config.use_absolute_position_embeddings: self.embed_positions = nn.Embed( self.config.max_text_length + self.offset, # image length for BOS embed_dim, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype) self.layernorm_embedding = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05 ) # postln is already applied in every layer if self.config.use_final_ln_encoder and self.config.ln_positions != "postln": self.final_ln = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, ) else: self.final_ln = None def __call__( self, input_ids, attention_mask, position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): input_shape = input_ids.shape input_ids = input_ids.reshape(-1, input_shape[-1]) hidden_states = self.embed_tokens(input_ids) * self.embed_scale if self.config.use_absolute_position_embeddings: embed_pos = self.embed_positions(position_ids + self.offset) hidden_states = hidden_states + embed_pos hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) outputs = self.layers( hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.final_ln is None: final_output = outputs[0] else: final_output = self.final_ln(outputs[0]) if not return_dict: return (final_output,) + outputs[1:] return 
FlaxBaseModelOutput( last_hidden_state=final_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class FlaxBartDecoder(nn.Module): config: DalleBartConfig embed_tokens: nn.Embed dtype: jnp.dtype = jnp.float32 # the dtype of the computation """ Edits: - offset set to 0 (no padding token) - use image_length instead of max_position_embeddings - use custom FlaxBartDecoderLayerCollection - embed_tokens cannot be None (issue at compile time) """ def setup(self): self.dropout_layer = nn.Dropout(rate=self.config.dropout) embed_dim = self.config.d_model self.padding_idx = self.config.pad_token_id self.embed_scale = ( math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 ) # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 0 if self.config.use_absolute_position_embeddings: self.embed_positions = nn.Embed( self.config.image_length + self.offset, # image length for BOS embed_dim, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype) self.layernorm_embedding = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05 ) # postln is already applied in every layer if self.config.use_final_ln_decoder and self.config.ln_positions != "postln": self.final_ln = norm( self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=self.config.force_ln_scale, ) def __call__( self, input_ids, attention_mask, position_ids, encoder_hidden_states: Optional[jnp.ndarray] = None, encoder_attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): input_shape = input_ids.shape input_ids = input_ids.reshape(-1, input_shape[-1]) hidden_states = self.embed_tokens(input_ids) * self.embed_scale if 
self.config.use_absolute_position_embeddings: embed_pos = self.embed_positions(position_ids + self.offset) hidden_states = hidden_states + embed_pos hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) outputs = self.layers( hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if self.final_ln is None: final_output = outputs[0] else: final_output = self.final_ln(outputs[0]) if not return_dict: return (final_output,) + outputs[1:] return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=final_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) class FlaxBartModule(FlaxBartModule): """ Edits - use custom FlaxBartEncoder & FlaxBartDecoder - use separate embeddings for Encoder & Decoder """ def setup(self): encoder_embed_tokens = nn.Embed( self.config.encoder_vocab_size, self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) decoder_embed_tokens = nn.Embed( self.config.image_vocab_size + 1, # image vocab size + 1 for BOS self.config.d_model, embedding_init=jax.nn.initializers.normal(self.config.init_std), ) self.encoder = FlaxBartEncoder( self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens ) self.decoder = FlaxBartDecoder( self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens ) class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule): """ Edits: - no bias - lm_head set to image_vocab_size + 1 (for BOS) - uses custom FlaxBartModule """ def setup(self): self.model = FlaxBartModule(config=self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.image_vocab_size + 1, # image vocab size + 1 for BOS to have same size as decoder inputs 
(for sharding) use_bias=False, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) def __call__( self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, position_ids, decoder_position_ids, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ): outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, position_ids=position_ids, decoder_position_ids=decoder_position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=deterministic, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_embedding = self.model.variables["params"]["shared"]["embedding"] lm_logits = self.lm_head.apply( {"params": {"kernel": shared_embedding.T}}, hidden_states ) else: lm_logits = self.lm_head(hidden_states) if not return_dict: output = (lm_logits,) + outputs[1:] return output return FlaxSeq2SeqLMOutput( logits=lm_logits, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, ) @flax.struct.dataclass class SampleState: cur_len: jnp.ndarray sequences: jnp.ndarray running_token: jnp.ndarray is_sent_finished: jnp.ndarray prng_key: jnp.ndarray model_kwargs: Dict[str, jnp.ndarray] model_kwargs_uncond: Dict[str, jnp.ndarray] @flax.struct.dataclass class FlaxSampleOutput(ModelOutput): """ Flax Base class for outputs of decoder-only generation models using sampling. Args: sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. 
""" sequences: jnp.ndarray = None class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration): """ Edits: - renamed from FlaxBartForConditionalGeneration - uses custom FlaxBartForConditionalGenerationModule - no bias in decode method - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues related to position embedding during model.generate() - custom generate method to allow super conditions - num_params property - unscan function """ module_class = FlaxBartForConditionalGenerationModule config_class = DalleBartConfig def num_params(self, params=None): if params is None: params = self.params num_params = jax.tree_util.tree_map( lambda param: param.size, flatten_dict(unfreeze(params)) ).values() return sum(list(num_params)) def unscan(self, params): if self.config.use_scan: self.config.use_scan = False params = flatten_dict(params) scanned_keys = [k for k in params.keys() if "layers" in k] for k in scanned_keys: v = params[k] name_idx = k.index("layers") + 1 for i in range(len(v)): new_k = ( *k[:name_idx], f"{k[name_idx][:-1]}_{i}", *k[name_idx + 1 :], ) params[new_k] = v[i] del params[k] params = unflatten_dict(params) return params def decode( self, decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[jnp.ndarray] = None, decoder_attention_mask: Optional[jnp.ndarray] = None, decoder_position_ids: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.return_dict ) encoder_hidden_states = encoder_outputs[0] if 
encoder_attention_mask is None: batch_size, sequence_length = encoder_hidden_states.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) if decoder_position_ids is None: if past_key_values is not None: raise ValueError( "Make sure to provide `decoder_position_ids` when passing `past_key_values`." ) decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed then cache is already initialized a private flag init_cache has to be # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that # it can be changed by FlaxBartAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward( module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs, ): decoder_module = module._get_decoder_module() outputs = decoder_module( decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_embedding = module.model.variables["params"]["shared"][ "embedding" ] lm_logits = module.lm_head.apply( {"params": {"kernel": shared_embedding.T}}, hidden_states ) else: lm_logits = module.lm_head(hidden_states) return lm_logits, outputs outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), 
output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) if past_key_values is None: lm_logits, decoder_outputs = outputs else: (lm_logits, decoder_outputs), past = outputs if return_dict: outputs = FlaxCausalLMOutputWithCrossAttentions( logits=lm_logits, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, ) else: outputs = (lm_logits,) + decoder_outputs[1:] # add updated cache to model output if past_key_values is not None and return_dict: outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.DeviceArray] = None,
        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the initial decoder kwargs consumed by the generation loop.

        Allocates the decoder cache with ``self.init_cache`` and returns a dict
        with keys ``past_key_values``, ``encoder_outputs``,
        ``encoder_attention_mask`` (the encoder-side ``attention_mask``),
        ``decoder_attention_mask`` (an all-ones ``(batch_size, max_length - 1)``
        mask whose front is overwritten by ``decoder_attention_mask`` when one is
        given) and ``decoder_position_ids``.

        Extra ``**kwargs`` are accepted and ignored.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        # NOTE(review): the cache holds max_length - 1 positions rather than
        # max_length; presumably because the last sampled token is never fed
        # back through the decoder -- confirm.
        past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
        if decoder_attention_mask is not None:
            # 0-based positions that only advance on unmasked tokens
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def generate(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jnp.ndarray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        condition_scale: Optional[float] = 1.0,
        input_ids_uncond: Optional[jnp.ndarray] = None,
        attention_mask_uncond: Optional[jnp.ndarray] = None,
        **model_kwargs,
    ):
        """Edit: Allow super conditioning.

        Generate decoder token sequences from ``input_ids``. Arguments left as
        ``None`` fall back to ``self.config``.

        For encoder-decoder models the prompt is encoded once (unless
        ``encoder_outputs`` is already in ``model_kwargs``) and decoding starts
        from a single ``decoder_start_token_id`` per row. Dispatch:

        * ``do_sample=False, num_beams == 1`` -> ``self._greedy_search``
        * ``do_sample=True,  num_beams == 1`` -> ``self._sample``
        * ``do_sample=False, num_beams > 1``  -> ``self._beam_search``
        * anything else raises ``NotImplementedError`` (beam sampling).

        Super conditioning (``condition_scale != 1.0``) also encodes
        ``input_ids_uncond`` / ``attention_mask_uncond``; ``_sample`` then mixes
        the per-step logits as ``uncond + condition_scale * (cond - uncond)``.
        It requires ``input_ids_uncond``, ``do_sample=True`` and
        ``num_beams == 1`` (checked with ``assert``).

        ``prng_key`` defaults to ``jax.random.PRNGKey(0)``, so sampling is
        deterministic unless a key is supplied.
        """
        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
        bos_token_id = (
            bos_token_id if bos_token_id is not None else self.config.bos_token_id
        )
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.eos_token_id
        )
        # NOTE(review): truthiness test, so an explicit decoder_start_token_id=0
        # is replaced by the config value (unlike the `is not None` checks above).
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id
            else self.config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError(
                "`decoder_start_token_id` has to be defined for encoder-decoder generation."
            )

        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                # keep the caller's kwargs so the unconditional pass reuses them
                model_kwargs_input = dict(model_kwargs)
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                    input_ids,
                    params,
                    {"attention_mask": attention_mask, **model_kwargs_input},
                )
                if condition_scale != 1.0:
                    assert (
                        input_ids_uncond is not None
                    ), "`input_ids_uncond` has to be defined for super conditioning."
                    assert (
                        do_sample is True
                    ), "`do_sample` has to be True for super conditioning."
                    assert (
                        num_beams == 1
                    ), "`num_beams` has to be 1 for super conditioning."
                    # encode the unconditional prompt with the same extra kwargs
                    model_kwargs_uncond = (
                        self._prepare_encoder_decoder_kwargs_for_generation(
                            input_ids_uncond,
                            params,
                            {
                                "attention_mask": attention_mask_uncond,
                                **model_kwargs_input,
                            },
                        )
                    )
                else:
                    model_kwargs_uncond = None
                # NOTE(review): `model_kwargs_uncond` is only bound on this path.
                # If `encoder_outputs` is passed in, or the model is not
                # encoder-decoder, the `_sample` call below would raise
                # NameError -- confirm before relying on precomputed encoder
                # outputs.
            # prepare decoder_input_ids for generation
            input_ids = (
                jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
            )

        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, temperature=temperature
            )
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
                condition_scale=condition_scale,
                model_kwargs_uncond=model_kwargs_uncond,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                # mutates the encoder_outputs mapping in place
                model_kwargs["encoder_outputs"][
                    "last_hidden_state"
                ] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"],
                    num_beams=num_beams,
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )

            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("`Beam sampling is currently not implemented.")

    def _sample(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor=None,
        logits_warper=None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
        condition_scale: float = 1.0,
        model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """Autoregressive sampling loop with optional super conditioning.

        Returns a ``FlaxSampleOutput`` whose ``sequences`` has shape
        ``(batch_size, max_length)``: ``input_ids`` at the front, then sampled
        tokens. Once a row samples ``eos_token_id``, that position and every
        later one hold ``pad_token_id``, as do positions never reached.
        The loop stops at ``max_length`` or when every row has finished.

        When ``condition_scale != 1.0`` the decoder is also run with
        ``model_kwargs_uncond`` at every step and the logits are mixed as
        ``uncond + condition_scale * (cond - uncond)``.
        """
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.eos_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        # wrap scalars as arrays so they can live in the while_loop state
        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            input_ids, max_length, **model_kwargs
        )
        if condition_scale != 1.0:
            # the unconditional branch gets its own cache
            model_kwargs_uncond = self.prepare_inputs_for_generation(
                input_ids, max_length, **model_kwargs_uncond
            )

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
            model_kwargs_uncond=model_kwargs_uncond,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(
                has_reached_max_length, all_sequence_finished
            )
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            # fresh key for this step, carry the other one forward
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(
                state.running_token, params=params, **state.model_kwargs
            )

            # logits of the last position only
            logits = model_outputs.logits[:, -1]

            # perform super conditioning
            # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
            if condition_scale != 1.0:
                model_outputs_uncond = model(
                    state.running_token, params=params, **state.model_kwargs_uncond
                )
                logits_uncond = model_outputs_uncond.logits[:, -1]
                logits = logits_uncond + condition_scale * (logits - logits_uncond)
            else:
                model_outputs_uncond = None

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            # NOTE(review): a warper's first argument is normally the token ids
            # (`state.sequences`, as passed to logits_processor above); passing
            # `logits` is only harmless if every warper ignores it -- confirm.
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (
                next_token == eos_token_id
            )
            # finished rows (including the row that just produced eos) emit pad
            next_token = (
                next_token * ~next_is_sent_finished
                + pad_token_id * next_is_sent_finished
            )
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(
                state.sequences, next_token, (0, state.cur_len)
            )
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs
            )
            next_model_kwargs_uncond = (
                self.update_inputs_for_generation(
                    model_outputs_uncond, state.model_kwargs_uncond
                )
                if condition_scale != 1.0
                else None
            )

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                model_kwargs_uncond=next_model_kwargs_uncond,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            # presumably an un-jitted Python loop for debugging -- see _run_loop_in_debug
            state = self._run_loop_in_debug(
                sample_search_cond_fn, sample_search_body_fn, state
            )
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)


================================================
FILE: src/dalle_mini/model/partitions.py
================================================
import re

from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P

# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py

# Sentinels
# Leaf placeholder meaning "no partition rule has matched this parameter path".
_unmatched = object()

# For specifying empty leaf dict `{}`
empty_dict = object()


def _match(qs, ks):
    """Return True if regexes in qs match any window of strings in
tuple ks.""" # compile regexes and force complete match qts = tuple(map(lambda x: re.compile(x + "$"), qs)) for i in range(len(ks) - len(qs) + 1): matches = [x.match(y) for x, y in zip(qts, ks[i:])] if matches and all(matches): return True return False def _replacement_rules(rules): def replace(key, val): for rule, replacement in rules: if _match(rule, key): return replacement return val return replace def _get_partition_rules(): return [ # embeddings (("embed_positions", "embedding"), P("mp", None)), (("embed_tokens", "embedding"), P("mp", None)), (("rel_bias", "embedding"), P(None, "mp")), # attention (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")), (("out_proj", "kernel"), P("mp", None)), # FFN (("Dense_0", "kernel"), P(None, "mp")), (("GLU.*", "Dense_1", "kernel"), P(None, "mp")), (("GLU.*", "Dense_2", "kernel"), P("mp", None)), (("FFN.*", "Dense_1", "kernel"), P("mp", None)), # layer norms (("(bias|scale)",), None), (("lm_head", "kernel"), P(None, "mp")), # head scale and tau (("(head_scale|tau)",), None), ] def set_partitions(in_dict, use_scan): rules = _get_partition_rules() replace = _replacement_rules(rules) initd = {k: _unmatched for k in flatten_dict(in_dict)} result = {k: replace(k, v) for k, v in initd.items()} for k, v in result.items(): if v == _unmatched: print(f"Unmatched -> {k}") l = list(result.keys()) if use_scan: # add None dimension to layers result = { k: (P(*(None,) + v) if v is not None else None) if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"]) else v for k, v in result.items() } assert _unmatched not in result.values(), "Incomplete partition spec." 
return freeze(unflatten_dict(result)) ================================================ FILE: src/dalle_mini/model/processor.py ================================================ """ DalleBart processor """ from typing import List import jax.numpy as jnp from .configuration import DalleBartConfig from .text import TextNormalizer from .tokenizer import DalleBartTokenizer from .utils import PretrainedFromWandbMixin class DalleBartProcessorBase: def __init__( self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int ): self.tokenizer = tokenizer self.normalize_text = normalize_text self.max_text_length = max_text_length if normalize_text: self.text_processor = TextNormalizer() # create unconditional tokens uncond = self.tokenizer( "", return_tensors="jax", padding="max_length", truncation=True, max_length=self.max_text_length, ).data self.input_ids_uncond = uncond["input_ids"] self.attention_mask_uncond = uncond["attention_mask"] def __call__(self, text: List[str] = None): # check that text is not a string assert not isinstance(text, str), "text must be a list of strings" if self.normalize_text: text = [self.text_processor(t) for t in text] res = self.tokenizer( text, return_tensors="jax", padding="max_length", truncation=True, max_length=self.max_text_length, ).data # tokens used only with super conditioning n = len(text) res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0) res["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, n, axis=0) return res @classmethod def from_pretrained(cls, *args, **kwargs): tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs) config = DalleBartConfig.from_pretrained(*args, **kwargs) return cls(tokenizer, config.normalize_text, config.max_text_length) class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase): pass ================================================ FILE: src/dalle_mini/model/text.py ================================================ """ Utilities 
for processing text. """ import html import math import random import re from pathlib import Path import emoji import ftfy from huggingface_hub import hf_hub_download from unidecode import unidecode # based on wiki word occurrence person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)] temp_token = "xtokx" # avoid repeating chars class HashtagProcessor: # Adapted from wordninja library # We use our wikipedia word count + a good heuristic to make it work def __init__(self): wiki_word_frequency = hf_hub_download( "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt" ) self._word_cost = ( l.split()[0] for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines() ) self._word_cost = { str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost) } self._max_word = max(len(x) for x in self._word_cost.keys()) self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+") def __call__(self, s): """Uses dynamic programming to infer the location of spaces in a string without spaces.""" l = [self._split(x) for x in self._SPLIT_RE.split(s)] return " ".join([item for sublist in l for item in sublist]) def _split(self, s): # Find the best match for the i first characters, assuming cost has # been built for the i-1 first characters. # Returns a pair (match_cost, match_length). def best_match(i): candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i])) return min( (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1) for k, c in candidates ) # Build the cost array cost = [0] for i in range(1, len(s) + 1): c, k = best_match(i) cost.append(c) # Backtrack to recover the minimal-cost string. 
out = [] i = len(s) while i > 0: c, k = best_match(i) assert c == cost[i] newToken = True if not s[i - k : i] == "'": # ignore a lone apostrophe if len(out) > 0: # re-attach split 's and split digits if out[-1] == "'s" or ( s[i - 1].isdigit() and out[-1][0].isdigit() ): # digit followed by digit out[-1] = ( s[i - k : i] + out[-1] ) # combine current token with previous token newToken = False if newToken: out.append(s[i - k : i]) i -= k return reversed(out) def replace_person_token(t): "Used for CC12M" t = re.sub("([,\s]*(and)*[,\s]*)+", " people ", t) while "" in t: t = t.replace( "", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1 ) return t def fix_html(t): # from OpenAI CLIP return html.unescape(html.unescape(t)) def replace_punctuation_with_commas(t): return re.sub("[()[\].,|:;?!=+~\-\/{}]", ",", t) def simplify_quotes(t): return re.sub("""['"`]""", ' " ', t) def merge_quotes(t): return re.sub('(\s*"+\s*)+', ' " ', t) def remove_comma_numbers(t): def _f(t): return re.sub("(\d),(\d{3})", r"\1\2", t) return _f(_f(t)) def pre_process_dot_numbers(t): return re.sub("(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t) def post_process_dot_numbers(t): return re.sub(f"{temp_token}dot{temp_token}", ".", t) def pre_process_quotes(t): # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've return re.sub( r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", rf"{temp_token}quote{temp_token}", t ) def post_process_quotes(t): return re.sub(f"{temp_token}quote{temp_token}", "'", t) def pre_process_dates(t): return re.sub("(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t) def post_process_dates(t): return re.sub(f"{temp_token}slash{temp_token}", "/", t) def merge_commas(t): return re.sub("(\s*,+\s*)+", ", ", t) def add_space_after_commas(t): return re.sub(",", ", ", t) def handle_special_chars(t): "Handle special characters" # replace "-" with a space when between words without space t = re.sub("(\w)-(\w)", r"\1 \2", t) # always add space around some characters return 
re.sub("([%&\/$*])", r" \1 ", t) def expand_hashtags(t, hashtag_processor): "Remove # and try to split words" return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t) _re_ignore_chars = r"[_#\\]" def ignore_chars(t): "Ignore useless characters" return re.sub(_re_ignore_chars, " ", t) def remove_extra_spaces(t): "Remove extra spaces (including \t and \n)" return re.sub("\s+", " ", t) def remove_repeating_chars(t): "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance" return re.sub(r"(\D)(\1{3,})", r"\1", t) def remove_urls(t): return re.sub(r"http\S+", "", t) def remove_html_tags(t): return re.sub("<[^<]+?>", " ", t) def remove_first_last_commas(t): t = t.strip() t = t[:-1] if t and t[-1] == "," else t t = t[1:] if t and t[0] == "," else t return t.strip() def remove_wiki_ref(t): t = re.sub(r"\A\s*\[\d+\]", "", t) return re.sub(r"\[\d+\]\s*\Z", "", t) class TextNormalizer: "Normalize text" def __init__(self): self._hashtag_processor = HashtagProcessor() def __call__(self, t): # fix some characters t = ftfy.fix_text(t) # fix html t = fix_html(t) # decode emojis (would be removed by unidecode) t = emoji.demojize(t) # decode and simplify text: see unidecode library t = unidecode(t) # lower case t = t.lower() # replace (for CC12M) t = replace_person_token(t) # remove wiki reference (for WIT) t = remove_wiki_ref(t) # remove html tags t = remove_html_tags(t) # remove urls t = remove_urls(t) # remove commas in numbers t = remove_comma_numbers(t) # handle dots in numbers and quotes - Part 1 t = pre_process_dot_numbers(t) t = pre_process_quotes(t) t = pre_process_dates(t) # handle special characters t = handle_special_chars(t) # handle hashtags t = expand_hashtags(t, self._hashtag_processor) # ignore useless characters t = ignore_chars(t) # simplify quotes t = simplify_quotes(t) # all punctuation becomes commas t = replace_punctuation_with_commas(t) # handle dots in numbers and quotes - Part 2 t = 
post_process_dot_numbers(t) t = post_process_quotes(t) t = post_process_dates(t) # handle repeating characters t = remove_repeating_chars(t) # merge quotes t = merge_quotes(t) # merge commas t = merge_commas(t) # remove multiple spaces t = remove_extra_spaces(t) # remove first and last comma t = remove_first_last_commas(t) # always start with a space return f" {t}" ================================================ FILE: src/dalle_mini/model/tokenizer.py ================================================ """ DalleBart tokenizer """ from transformers import BartTokenizerFast from .utils import PretrainedFromWandbMixin class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast): pass ================================================ FILE: src/dalle_mini/model/utils.py ================================================ import os import tempfile from pathlib import Path import wandb class PretrainedFromWandbMixin: @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ Initializes from a wandb artifact or delegates loading to the superclass. 
""" with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies if ":" in pretrained_model_name_or_path and not os.path.isdir( pretrained_model_name_or_path ): # wandb artifact if wandb.run is not None: artifact = wandb.run.use_artifact(pretrained_model_name_or_path) else: artifact = wandb.Api().artifact(pretrained_model_name_or_path) pretrained_model_name_or_path = artifact.download(tmp_dir) return super(PretrainedFromWandbMixin, cls).from_pretrained( pretrained_model_name_or_path, *model_args, **kwargs ) ================================================ FILE: tools/dataset/encode_dataset.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "d0b72877", "metadata": {}, "source": [ "# Pre-encoding a dataset for DALLE·mini" ] }, { "cell_type": "markdown", "id": "ba7b31e6", "metadata": {}, "source": [ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n", "\n", "Adapt it to your own dataset and image encoder.\n", "\n", "At the end you should have a dataset of pairs:\n", "* a caption defined as a string\n", "* an encoded image defined as a list of int." 
] }, { "cell_type": "code", "execution_count": null, "id": "3b59489e", "metadata": {}, "outputs": [], "source": [ "from tqdm.notebook import tqdm\n", "\n", "import torchvision.transforms as T\n", "\n", "import webdataset as wds\n", "\n", "import jax\n", "import braceexpand\n", "from pathlib import Path" ] }, { "cell_type": "markdown", "id": "c7c4c1e6", "metadata": {}, "source": [ "## Configuration Parameters" ] }, { "cell_type": "code", "execution_count": 3, "id": "1265dbfe", "metadata": {}, "outputs": [], "source": [ "shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n", "encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n", "\n", "VQGAN_REPO, VQGAN_COMMIT_ID = (\n", " \"dalle-mini/vqgan_imagenet_f16_16384\",\n", " \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n", ")\n", "\n", "# good defaults for a TPU v3-8\n", "batch_size = 128 # Per device\n", "num_workers = 8 # For parallel processing\n", "total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n", "save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)" ] }, { "cell_type": "code", "execution_count": 5, "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['XXX/shard-0000.tar',\n", " 'XXX/shard-0001.tar',\n", " 'XXX/shard-0002.tar',\n", " 'XXX/shard-0003.tar',\n", " 'XXX/shard-0004.tar',\n", " 'XXX/shard-0005.tar',\n", " 'XXX/shard-0006.tar',\n", " 'XXX/shard-0007.tar',\n", " 'XXX/shard-0008.tar']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shards = list(\n", " braceexpand.braceexpand(shards)\n", ") # better display for tqdm with known length" ] }, { "cell_type": "markdown", "id": "75dba8e2", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "markdown", "id": "a1e8fb95", "metadata": {}, "source": [ "We load data using `webdataset`." 
] }, { "cell_type": "code", "execution_count": null, "id": "9ef5de9e", "metadata": {}, "outputs": [], "source": [ "ds = (\n", " wds.WebDataset(shards, handler=wds.warn_and_continue)\n", " .decode(\"rgb\", handler=wds.warn_and_continue)\n", " .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n", " .batched(total_bs) # load in batch per worker (faster)\n", ")" ] }, { "cell_type": "markdown", "id": "90981824", "metadata": {}, "source": [ "Note:\n", "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n", "* you may need to resize images in your pipeline (with `map_dict` for example); we assume they are already 256x256.\n", "* you can also filter out some items using `select`." ] }, { "cell_type": "markdown", "id": "129c377d", "metadata": {}, "source": [ "We can now inspect our data." ] }, { "cell_type": "code", "execution_count": null, "id": "8cac98cb", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "images, captions = next(iter(ds))" ] }, { "cell_type": "code", "execution_count": null, "id": "cd268fbf", "metadata": {}, "outputs": [], "source": [ "images.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "5acfc4d8", "metadata": {}, "outputs": [], "source": [ "captions[:10]" ] }, { "cell_type": "code", "execution_count": null, "id": "c24693c0", "metadata": {}, "outputs": [], "source": [ "T.ToPILImage()(images[0].permute(2, 0, 1))" ] }, { "cell_type": "markdown", "id": "3059ffb1", "metadata": {}, "source": [ "Finally, we create our dataloader."
] }, { "cell_type": "code", "execution_count": null, "id": "c227c551", "metadata": {}, "outputs": [], "source": [ "dl = (\n", " wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n", ") # avoid partial batch at the end of each worker" ] }, { "cell_type": "markdown", "id": "a354472b", "metadata": {}, "source": [ "## Image encoder\n", "\n", "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model." ] }, { "cell_type": "code", "execution_count": null, "id": "47a8b818", "metadata": { "scrolled": true }, "outputs": [], "source": [ "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from flax.jax_utils import replicate\n", "\n", "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n", "vqgan_params = replicate(vqgan.params)" ] }, { "cell_type": "markdown", "id": "62ad01c3", "metadata": {}, "source": [ "## Encoding" ] }, { "cell_type": "markdown", "id": "20357f74", "metadata": {}, "source": [ "Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`." 
] }, { "cell_type": "code", "execution_count": null, "id": "322a4619", "metadata": {}, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "from functools import partial\n", "\n", "\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_encode(batch, params):\n", " # Not sure if we should `replicate` params, does not seem to have any effect\n", " _, indices = vqgan.encode(batch, params=params)\n", " return indices" ] }, { "cell_type": "code", "execution_count": null, "id": "ff6c10d4", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "def encode_dataset(dataloader, output_dir, save_frequency):\n", " output_dir.mkdir(parents=True, exist_ok=True)\n", " all_captions = []\n", " all_encoding = []\n", " n_file = 1\n", " for idx, (images, captions) in enumerate(tqdm(dataloader)):\n", " images = images.numpy()\n", " n = len(images) // 8 * 8\n", " if n != len(images):\n", " # get the max number of images we can (multiple of 8)\n", " print(f\"Different sizes {n} vs {len(images)}\")\n", " images = images[:n]\n", " captions = captions[:n]\n", " if not len(captions):\n", " print(f\"No images/captions in batch...\")\n", " continue\n", " images = shard(images)\n", " encoded = p_encode(images, vqgan_params)\n", " encoded = encoded.reshape(-1, encoded.shape[-1])\n", " all_captions.extend(captions)\n", " all_encoding.extend(encoded.tolist())\n", "\n", " # save files\n", " if (idx + 1) % save_frequency == 0:\n", " print(f\"Saving file {n_file}\")\n", " batch_df = pd.DataFrame.from_dict(\n", " {\"caption\": all_captions, \"encoding\": all_encoding}\n", " )\n", " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n", " all_captions = []\n", " all_encoding = []\n", " n_file += 1\n", "\n", " if len(all_captions):\n", " print(f\"Saving final file {n_file}\")\n", " batch_df = pd.DataFrame.from_dict(\n", " {\"caption\": all_captions, \"encoding\": all_encoding}\n", " )\n", " 
batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7704863d", "metadata": {}, "outputs": [], "source": [ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)" ] }, { "cell_type": "markdown", "id": "8953dd84", "metadata": {}, "source": [ "----" ] } ], "metadata": { "interpreter": { "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: tools/inference/inference_pipeline.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "118UKH5bWCGa" }, "source": [ "# DALL·E mini - Inference pipeline\n", "\n", "*Generate images from a text prompt*\n", "\n", "\n", "\n", "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n", "\n", "Just want to play? Use directly [the app](https://www.craiyon.com/).\n", "\n", "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)." 
] }, { "cell_type": "markdown", "metadata": { "id": "dS8LbaonYm3a" }, "source": [ "## 🛠️ Installation and set-up" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uzjAM2GBYpZX" }, "outputs": [], "source": [ "# Required only for colab environments + GPU\n", "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", "\n", "# Install required libraries\n", "!pip install -q dalle-mini\n", "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git" ] }, { "cell_type": "markdown", "metadata": { "id": "ozHzTkyv8cqU" }, "source": [ "We load required models:\n", "* DALL·E mini for text to encoded images\n", "* VQGAN for decoding images\n", "* CLIP for scoring predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K6CxW2o42f-w" }, "outputs": [], "source": [ "# Model references\n", "\n", "# dalle-mega\n", "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n", "DALLE_COMMIT_ID = None\n", "\n", "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n", "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n", "\n", "# VQGAN model\n", "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n", "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yv-aR3t4Oe5v" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# check how many devices are available\n", "jax.local_device_count()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "92zYmvsQ38vL" }, "outputs": [], "source": [ "# Load models & tokenizer\n", "from dalle_mini import DalleBart, DalleBartProcessor\n", "from vqgan_jax.modeling_flax_vqgan import VQModel\n", "from transformers import CLIPProcessor, FlaxCLIPModel\n", "\n", "# Load dalle-mini\n", "model, params 
= DalleBart.from_pretrained(\n", " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", ")\n", "\n", "# Load VQGAN\n", "vqgan, vqgan_params = VQModel.from_pretrained(\n", " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "o_vH2X1tDtzA" }, "source": [ "Model parameters are replicated on each device for faster inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wtvLoM48EeVw" }, "outputs": [], "source": [ "from flax.jax_utils import replicate\n", "\n", "params = replicate(params)\n", "vqgan_params = replicate(vqgan_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "0A9AHQIgZ_qw" }, "source": [ "Model functions are compiled and parallelized to take advantage of multiple devices." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sOtoOmYsSYPz" }, "outputs": [], "source": [ "from functools import partial\n", "\n", "\n", "# model inference\n", "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n", "def p_generate(\n", " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n", "):\n", " return model.generate(\n", " **tokenized_prompt,\n", " prng_key=key,\n", " params=params,\n", " top_k=top_k,\n", " top_p=top_p,\n", " temperature=temperature,\n", " condition_scale=condition_scale,\n", " )\n", "\n", "\n", "# decode image\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_decode(indices, params):\n", " return vqgan.decode_code(indices, params=params)" ] }, { "cell_type": "markdown", "metadata": { "id": "HmVN6IBwapBA" }, "source": [ "A different random key is passed to the model on each device, so that each device generates different images." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4CTXmlUkThhX" }, "outputs": [], "source": [ "import random\n", "\n", "# create a random key\n", "seed = random.randint(0, 2**32 - 1)\n", "key = jax.random.PRNGKey(seed)" ] }, { "cell_type": "markdown", "metadata": { "id": "BrnVyCo81pij" }, "source": [ "## 🖍 Text Prompt" ] }, { "cell_type": "markdown", "metadata": { "id": "rsmj0Aj5OQox" }, "source": [ "Our model requires the prompts to be processed before they can be used." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YjjhUychOVxm" }, "outputs": [], "source": [ "from dalle_mini import DalleBartProcessor\n", "\n", "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)" ] }, { "cell_type": "markdown", "metadata": { "id": "BQ7fymSPyvF_" }, "source": [ "Let's define some text prompts." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_0vI9ge1oKr" }, "outputs": [], "source": [ "prompts = [\n", " \"sunset over a lake in the mountains\",\n", " \"the Eiffel tower landing on the moon\",\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "XlZUG3SCLnGE" }, "source": [ "Note: we could use the same prompt multiple times for faster inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VKjEZGjtO49k" }, "outputs": [], "source": [ "tokenized_prompts = processor(prompts)" ] }, { "cell_type": "markdown", "metadata": { "id": "-CEJBnuJOe5z" }, "source": [ "Finally, we replicate the prompts onto each device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lQePgju5Oe5z" }, "outputs": [], "source": [ "tokenized_prompt = replicate(tokenized_prompts)" ] }, { "cell_type": "markdown", "metadata": { "id": "phQ9bhjRkgAZ" }, "source": [ "## 🎨 Generate images\n", "\n", "We generate images using the dalle-mini model and decode them with the VQGAN." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d0wVkXpKqnHA" }, "outputs": [], "source": [ "# number of predictions per prompt\n", "n_predictions = 8\n", "\n", "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n", "gen_top_k = None\n", "gen_top_p = None\n", "temperature = None\n", "cond_scale = 10.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SDjEx9JxR3v8" }, "outputs": [], "source": [ "from flax.training.common_utils import shard_prng_key\n", "import numpy as np\n", "from PIL import Image\n", "from tqdm.notebook import trange\n", "\n", "print(f\"Prompts: {prompts}\\n\")\n", "# generate images\n", "images = []\n", "for i in trange(max(n_predictions // jax.device_count(), 1)):\n", " # get a new key\n", " key, subkey = jax.random.split(key)\n", " # generate images\n", " encoded_images = p_generate(\n", " tokenized_prompt,\n", " shard_prng_key(subkey),\n", " params,\n", " gen_top_k,\n", " gen_top_p,\n", " temperature,\n", " cond_scale,\n", " )\n", " # remove BOS\n", " encoded_images = encoded_images.sequences[..., 1:]\n", " # decode images\n", " decoded_images = p_decode(encoded_images, vqgan_params)\n", " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n", " for decoded_img in decoded_images:\n", " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n", " images.append(img)\n", " display(img)\n", " print()" ] }, { "cell_type": "markdown", "metadata": { "id": "tw02wG9zGmyB" }, "source": [ "## 🏅 Optional: Rank images by CLIP score\n", "\n", "We can rank images according to CLIP.\n", "\n", "**Note: your session may crash if you don't have a subscription to Colab Pro.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RGjlIW_f6GA0" }, "outputs": [], "source": [ "# CLIP model\n", "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n", "CLIP_COMMIT_ID = None\n", "\n", "# Load CLIP\n", "clip, clip_params = 
FlaxCLIPModel.from_pretrained(\n", " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", ")\n", "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n", "clip_params = replicate(clip_params)\n", "\n", "\n", "# score images\n", "@partial(jax.pmap, axis_name=\"batch\")\n", "def p_clip(inputs, params):\n", " logits = clip(params=params, **inputs).logits_per_image\n", " return logits" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FoLXpjCmGpju" }, "outputs": [], "source": [ "from flax.training.common_utils import shard\n", "\n", "# get clip scores\n", "clip_inputs = clip_processor(\n", " text=prompts * jax.device_count(),\n", " images=images,\n", " return_tensors=\"np\",\n", " padding=\"max_length\",\n", " max_length=77,\n", " truncation=True,\n", ").data\n", "logits = p_clip(shard(clip_inputs), clip_params)\n", "\n", "# organize scores per prompt\n", "p = len(prompts)\n", "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()" ] }, { "cell_type": "markdown", "metadata": { "id": "4AAWRm70LgED" }, "source": [ "Let's now display images ranked by CLIP score." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zsgxxubLLkIu" }, "outputs": [], "source": [ "for i, prompt in enumerate(prompts):\n", " print(f\"Prompt: {prompt}\\n\")\n", " for idx in logits[i].argsort()[::-1]:\n", " display(images[idx * p + i])\n", " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n", " print()" ] }, { "cell_type": "markdown", "metadata": { "id": "oZT9i3jCjir0" }, "source": [ "## 🪄 Optional: Save your Generated Images as W&B Tables\n", "\n", "W&B Tables is an interactive 2D grid with support for rich media logging. Use it to save the generated images to the W&B dashboard and share them with the world." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-pSiv6Vwjkn0" }, "outputs": [], "source": [ "import wandb\n", "\n", "# Initialize a W&B run.\n", "project = \"dalle-mini-tables-colab\"\n", "run = wandb.init(project=project)\n", "\n", "# Initialize an empty W&B Tables.\n", "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n", "gen_table = wandb.Table(columns=columns)\n", "\n", "# Add data to the table.\n", "for i, prompt in enumerate(prompts):\n", " # If CLIP scores exist, sort the Images\n", " if logits is not None:\n", " idxs = logits[i].argsort()[::-1]\n", " tmp_imgs = images[i :: len(prompts)]\n", " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n", " else:\n", " tmp_imgs = images[i :: len(prompts)]\n", "\n", " # Add the data to the table.\n", " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n", "\n", "# Log the Table to W&B dashboard.\n", "wandb.log({\"Generated Images\": gen_table})\n", "\n", "# Close the W&B run.\n", "run.finish()" ] }, { "cell_type": "markdown", "metadata": { "id": "Ck2ZnHwVjnRd" }, "source": [ "Click on the link above to check out your generated images." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "machine_shape": "hm", "name": "DALL·E mini - Inference pipeline.ipynb", "provenance": [], "gpuType": "A100", "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: tools/inference/run_infer_notebook.sh ================================================ #!/bin/bash jupyter notebook --ip 0.0.0.0 --no-browser --allow-root ================================================ FILE: tools/train/config/mega/config.json ================================================ { "activation_dropout": 0.0, "activation_function": "gelu", "attention_dropout": 0.0, "bos_token_id": 16385, "d_model": 2048, "decoder_attention_heads": 32, "decoder_ffn_dim": 4096, "decoder_layerdrop": 0.0, "decoder_layers": 24, "decoder_start_token_id": 16384, "do_sample": true, "dropout": 0.0, "encoder_attention_heads": 32, "encoder_ffn_dim": 4096, "encoder_layerdrop": 0.0, "encoder_layers": 24, "encoder_vocab_size": 50272, "eos_token_id": 16385, "force_ln_scale": false, "gradient_checkpointing": false, "image_length": 256, "image_vocab_size": 16415, "init_std": 0.01, "is_encoder_decoder": true, "ln_positions": "normformer", "ln_type": "layernorm", "max_length": 257, "max_text_length": 64, "min_length": 257, "model_type": "dallebart", "normalize_text": true, "pad_token_id": 16385, "scale_embedding": false, "sinkhorn_iters": 1, "tau_init": 0.05, "tie_word_embeddings": false, "use_absolute_position_embeddings": true, "use_alibi": false, "use_bias": false, "use_cache": true, "use_cosine_attention": false, "use_deepnet_scaling": false, "use_final_ln_decoder": true, "use_final_ln_encoder": true, 
"use_glu": true, "use_head_scale": false, "use_swin_position_embeddings": false } ================================================ FILE: tools/train/config/micro/config.json ================================================ { "activation_dropout": 0.0, "activation_function": "gelu", "attention_dropout": 0.0, "bos_token_id": 16385, "d_model": 256, "decoder_attention_heads": 2, "decoder_ffn_dim": 256, "decoder_layerdrop": 0.0, "decoder_layers": 2, "decoder_start_token_id": 16384, "dropout": 0.0, "encoder_attention_heads": 2, "encoder_ffn_dim": 256, "encoder_layerdrop": 0.0, "encoder_layers": 2, "encoder_vocab_size": 50264, "eos_token_id": 16385, "image_length": 256, "image_vocab_size": 16391, "init_std": 0.02, "is_encoder_decoder": true, "max_text_length": 64, "model_type": "dallebart", "normalize_text": true, "pad_token_id": 16385, "scale_embedding": false, "tie_word_embeddings": false, "use_cache": true } ================================================ FILE: tools/train/config/mini/config.json ================================================ { "activation_dropout": 0.0, "activation_function": "gelu", "attention_dropout": 0.0, "bos_token_id": 16385, "d_model": 1024, "decoder_attention_heads": 16, "decoder_ffn_dim": 4096, "decoder_layers": 12, "decoder_start_token_id": 16384, "dropout": 0.0, "encoder_attention_heads": 16, "encoder_ffn_dim": 4096, "encoder_layers": 12, "encoder_vocab_size": 50264, "eos_token_id": 16385, "gradient_checkpointing": false, "image_length": 256, "image_vocab_size": 16391, "init_std": 0.02, "is_encoder_decoder": true, "max_text_length": 64, "model_type": "dallebart", "normalize_text": true, "pad_token_id": 16385, "scale_embedding": false, "tie_word_embeddings": false, "use_cache": true } ================================================ FILE: tools/train/config/mini_glu/config.json ================================================ { "activation_dropout": 0.0, "activation_function": "gelu", "attention_dropout": 0.0, "bos_token_id": 16385, 
"d_model": 1024, "decoder_attention_heads": 16, "decoder_ffn_dim": 2730, "decoder_layers": 12, "decoder_start_token_id": 16384, "dropout": 0.0, "encoder_attention_heads": 16, "encoder_ffn_dim": 2730, "encoder_layers": 12, "encoder_vocab_size": 50300, "eos_token_id": 16385, "gradient_checkpointing": false, "image_length": 256, "image_vocab_size": 16400, "init_std": 0.02, "is_encoder_decoder": true, "max_text_length": 64, "model_type": "dallebart", "normalize_text": true, "pad_token_id": 16385, "scale_embedding": false, "tie_word_embeddings": false, "use_scan": false, "use_cache": true } ================================================ FILE: tools/train/embeddings_retrain_preparation.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "118UKH5bWCGa" }, "source": [ "# DALL·E mini - Embedding Retrain Preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll start with the dalle-mini model for faster experimentation." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "K6CxW2o42f-w" }, "outputs": [], "source": [ "DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n", "DALLE_COMMIT_ID = None\n", "\n", "# # dalle-mega\n", "# DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n", "# DALLE_COMMIT_ID = None" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Yv-aR3t4Oe5v" }, "outputs": [ { "data": { "text/plain": [ "8" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# check how many devices are available\n", "jax.local_device_count()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We load the model twice to keep a copy of the original parameters." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "92zYmvsQ38vL" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n", "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @ 0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 
0:0:1.2\n", "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @ 0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n" ] } ], "source": [ "# Load model\n", "from dalle_mini import DalleBart, DalleBartProcessor\n", "\n", "model, params = DalleBart.from_pretrained(\n", " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", ")\n", "\n", "_, params_original = DalleBart.from_pretrained(\n", " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model surgery: remove layers to be retrained" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at the params tree." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "437833712" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(x.size for x in jax.tree_leaves(params))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"lm_head\": {\n", " \"kernel\": [\n", " 1024,\n", " 16385\n", " ]\n", " },\n", " \"model\": {\n", " \"decoder\": {\n", " \"embed_positions\": {\n", " \"embedding\": [\n", " 256,\n", " 1024\n", " ]\n", " },\n", " \"embed_tokens\": {\n", " \"embedding\": [\n", " 16385,\n", " 1024\n", " ]\n", " },\n", " \"final_ln\": {\n", " \"bias\": [\n", " 1024\n", " ]\n", " },\n", " \"layernorm_embedding\": {\n", " \"bias\": [\n", " 1024\n", " ],\n", " \"scale\": [\n", " 1024\n", " ]\n", " },\n", " \"layers\": {\n", " \"FlaxBartDecoderLayers\": {\n", " \"FlaxBartAttention_0\": {\n", " \"k_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"out_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"q_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"v_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " }\n", " },\n", " \"FlaxBartAttention_1\": {\n", " \"k_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"out_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"q_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"v_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " }\n", " },\n", " \"GLU_0\": {\n", " \"Dense_0\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 2730\n", " ]\n", " },\n", " \"Dense_1\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 2730\n", " ]\n", " },\n", " \"Dense_2\": {\n", " 
\"kernel\": [\n", " 12,\n", " 2730,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_0\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_1\": {\n", " \"bias\": [\n", " 12,\n", " 2730\n", " ]\n", " }\n", " },\n", " \"LayerNorm_0\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_1\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ],\n", " \"scale\": [\n", " 12,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_2\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_3\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ],\n", " \"scale\": [\n", " 12,\n", " 1024\n", " ]\n", " }\n", " }\n", " }\n", " },\n", " \"encoder\": {\n", " \"embed_positions\": {\n", " \"embedding\": [\n", " 64,\n", " 1024\n", " ]\n", " },\n", " \"embed_tokens\": {\n", " \"embedding\": [\n", " 50264,\n", " 1024\n", " ]\n", " },\n", " \"final_ln\": {\n", " \"bias\": [\n", " 1024\n", " ]\n", " },\n", " \"layernorm_embedding\": {\n", " \"bias\": [\n", " 1024\n", " ],\n", " \"scale\": [\n", " 1024\n", " ]\n", " },\n", " \"layers\": {\n", " \"FlaxBartEncoderLayers\": {\n", " \"FlaxBartAttention_0\": {\n", " \"k_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"out_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"q_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " },\n", " \"v_proj\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 1024\n", " ]\n", " }\n", " },\n", " \"GLU_0\": {\n", " \"Dense_0\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 2730\n", " ]\n", " },\n", " \"Dense_1\": {\n", " \"kernel\": [\n", " 12,\n", " 1024,\n", " 2730\n", " ]\n", " },\n", " \"Dense_2\": {\n", " \"kernel\": [\n", " 12,\n", " 2730,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_0\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_1\": {\n", " \"bias\": [\n", " 12,\n", " 2730\n", " ]\n", " }\n", " },\n", " 
\"LayerNorm_0\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ]\n", " },\n", " \"LayerNorm_1\": {\n", " \"bias\": [\n", " 12,\n", " 1024\n", " ],\n", " \"scale\": [\n", " 12,\n", " 1024\n", " ]\n", " }\n", " }\n", " }\n", " }\n", " }\n", "}\n" ] } ], "source": [ "import json\n", "\n", "tree = jax.tree_map(lambda x: x.shape, params)\n", "print(json.dumps(tree, indent=2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will remove or reinitialize:\n", "- `lm_head`\n", "- `model.decoder.embed_positions`\n", "- `model.decoder.embed_tokens`\n", "- `model.decoder.final_ln`\n", "- `model.decoder.layernorm_embedding`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "del params[\"lm_head\"]\n", "for layer in [\"embed_positions\", \"embed_tokens\", \"final_ln\", \"layernorm_embedding\"]:\n", " del params[\"model\"][\"decoder\"][layer]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model': {'decoder': {'layers': {'FlaxBartDecoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n", " 1024,\n", " 1024)},\n", " 'out_proj': {'kernel': (12, 1024, 1024)},\n", " 'q_proj': {'kernel': (12, 1024, 1024)},\n", " 'v_proj': {'kernel': (12, 1024, 1024)}},\n", " 'FlaxBartAttention_1': {'k_proj': {'kernel': (12, 1024, 1024)},\n", " 'out_proj': {'kernel': (12, 1024, 1024)},\n", " 'q_proj': {'kernel': (12, 1024, 1024)},\n", " 'v_proj': {'kernel': (12, 1024, 1024)}},\n", " 'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n", " 'Dense_1': {'kernel': (12, 1024, 2730)},\n", " 'Dense_2': {'kernel': (12, 2730, 1024)},\n", " 'LayerNorm_0': {'bias': (12, 1024)},\n", " 'LayerNorm_1': {'bias': (12, 2730)}},\n", " 'LayerNorm_0': {'bias': (12, 1024)},\n", " 'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)},\n", " 'LayerNorm_2': {'bias': (12, 1024)},\n", " 'LayerNorm_3': {'bias': (12, 1024), 'scale': (12, 1024)}}}},\n", " 'encoder': {'embed_positions': 
{'embedding': (64, 1024)},\n", " 'embed_tokens': {'embedding': (50264, 1024)},\n", " 'final_ln': {'bias': (1024,)},\n", " 'layernorm_embedding': {'bias': (1024,), 'scale': (1024,)},\n", " 'layers': {'FlaxBartEncoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n", " 1024,\n", " 1024)},\n", " 'out_proj': {'kernel': (12, 1024, 1024)},\n", " 'q_proj': {'kernel': (12, 1024, 1024)},\n", " 'v_proj': {'kernel': (12, 1024, 1024)}},\n", " 'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n", " 'Dense_1': {'kernel': (12, 1024, 2730)},\n", " 'Dense_2': {'kernel': (12, 2730, 1024)},\n", " 'LayerNorm_0': {'bias': (12, 1024)},\n", " 'LayerNorm_1': {'bias': (12, 2730)}},\n", " 'LayerNorm_0': {'bias': (12, 1024)},\n", " 'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)}}}}}}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.tree_map(lambda x: x.shape, params)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "404012016" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(x.size for x in jax.tree_leaves(params))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reinitialize layers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We save a checkpoint and then reload it. Reloading does not automatically reinitialize the missing keys, but it sets `_missing_keys` appropriately so we can initialize them later. We could achieve the same by setting that property ourselves, but I'll refrain from doing so because it's a private implementation detail." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "trimmed_checkpoint = \"mini-trimmed\"" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "tcmalloc: large alloc 1610424320 bytes == 0x5632d11c8000 @ 0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n", "tcmalloc: large alloc 3231449088 bytes == 0x56333119a000 @ 0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n" ] } ], "source": [ "model.save_pretrained(trimmed_checkpoint, params=params)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The checkpoint mini-trimmed is missing required keys: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}. 
Make sure to call model.init_weights to initialize the missing weights.\n", "Some weights of DalleBart were not initialized from the model checkpoint at mini-trimmed and are newly initialized: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model, params = DalleBart.from_pretrained(\n", " trimmed_checkpoint, revision=None, dtype=jnp.float16, _do_init=False\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{('lm_head', 'kernel'),\n", " ('model', 'decoder', 'embed_positions', 'embedding'),\n", " ('model', 'decoder', 'embed_tokens', 'embedding'),\n", " ('model', 'decoder', 'final_ln', 'bias'),\n", " ('model', 'decoder', 'layernorm_embedding', 'bias'),\n", " ('model', 'decoder', 'layernorm_embedding', 'scale')}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model._missing_keys" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "params_reinit = model.init_weights(model.key, model.input_shape, params=params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Verification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The structure should be the same as the original `params` dict. Re-initialized layers should have different parameters, but existing layers should be the same." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "FrozenDict({\n", " lm_head: {\n", " kernel: (1024, 16385),\n", " },\n", " model: {\n", " decoder: {\n", " embed_positions: {\n", " embedding: (256, 1024),\n", " },\n", " embed_tokens: {\n", " embedding: (16385, 1024),\n", " },\n", " final_ln: {\n", " bias: (1024,),\n", " },\n", " layernorm_embedding: {\n", " bias: (1024,),\n", " scale: (1024,),\n", " },\n", " layers: {\n", " FlaxBartDecoderLayers: {\n", " FlaxBartAttention_0: {\n", " k_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " out_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " q_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " v_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " },\n", " FlaxBartAttention_1: {\n", " k_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " out_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " q_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " v_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " },\n", " GLU_0: {\n", " Dense_0: {\n", " kernel: (12, 1024, 2730),\n", " },\n", " Dense_1: {\n", " kernel: (12, 1024, 2730),\n", " },\n", " Dense_2: {\n", " kernel: (12, 2730, 1024),\n", " },\n", " LayerNorm_0: {\n", " bias: (12, 1024),\n", " },\n", " LayerNorm_1: {\n", " bias: (12, 2730),\n", " },\n", " },\n", " LayerNorm_0: {\n", " bias: (12, 1024),\n", " },\n", " LayerNorm_1: {\n", " bias: (12, 1024),\n", " scale: (12, 1024),\n", " },\n", " LayerNorm_2: {\n", " bias: (12, 1024),\n", " },\n", " LayerNorm_3: {\n", " bias: (12, 1024),\n", " scale: (12, 1024),\n", " },\n", " },\n", " },\n", " },\n", " encoder: {\n", " embed_positions: {\n", " embedding: (64, 1024),\n", " },\n", " embed_tokens: {\n", " embedding: (50264, 1024),\n", " },\n", " final_ln: {\n", " bias: (1024,),\n", " },\n", " layernorm_embedding: {\n", " bias: (1024,),\n", " scale: (1024,),\n", " },\n", " layers: {\n", " FlaxBartEncoderLayers: {\n", " 
FlaxBartAttention_0: {\n", " k_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " out_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " q_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " v_proj: {\n", " kernel: (12, 1024, 1024),\n", " },\n", " },\n", " GLU_0: {\n", " Dense_0: {\n", " kernel: (12, 1024, 2730),\n", " },\n", " Dense_1: {\n", " kernel: (12, 1024, 2730),\n", " },\n", " Dense_2: {\n", " kernel: (12, 2730, 1024),\n", " },\n", " LayerNorm_0: {\n", " bias: (12, 1024),\n", " },\n", " LayerNorm_1: {\n", " bias: (12, 2730),\n", " },\n", " },\n", " LayerNorm_0: {\n", " bias: (12, 1024),\n", " },\n", " LayerNorm_1: {\n", " bias: (12, 1024),\n", " scale: (12, 1024),\n", " },\n", " },\n", " },\n", " },\n", " },\n", "})" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.tree_map(lambda x: x.shape, params_reinit)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FrozenDict({\n", " embedding: DeviceArray([[ 0.00582082, -0.04113895, 0.00918633, ..., -0.00530822,\n", " 0.01297319, 0.02720674],\n", " [ 0.03540739, 0.03676804, -0.02924041, ..., 0.00163185,\n", " -0.01938273, -0.02105987],\n", " [ 0.00478452, -0.03438002, -0.0024974 , ..., -0.03892584,\n", " 0.01721252, 0.02605445],\n", " ...,\n", " [ 0.02495495, 0.00559381, -0.01588043, ..., 0.01393714,\n", " -0.01824111, -0.02007291],\n", " [ 0.00983252, -0.00180564, -0.01686333, ..., -0.01001718,\n", " 0.01886345, -0.00393983],\n", " [-0.03589988, -0.00455565, 0.00076276, ..., -0.02145007,\n", " -0.00180798, -0.0133148 ]], dtype=float32),\n", "})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params_reinit[\"model\"][\"decoder\"][\"embed_positions\"]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray(-0.09320386, dtype=float32),\n", " DeviceArray(0.08769083, dtype=float32))" ] }, 
"execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embedding_new = params_reinit[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n", "embedding_new.min(), embedding_new.max()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'embedding': DeviceArray([[ 0.03459017, -0.0065838 , -0.11748601, ..., -0.01451578,\n", " -0.03927238, -0.00266367],\n", " [-0.03116009, 0.00438436, 0.02691377, ..., -0.02886203,\n", " -0.01095741, -0.02649871],\n", " [-0.03568491, -0.0086962 , 0.01851564, ..., -0.04736514,\n", " 0.05310551, -0.01648099],\n", " ...,\n", " [-0.02454913, 0.03746822, -0.02269235, ..., 0.03377315,\n", " 0.003004 , 0.04975331],\n", " [-0.05145862, 0.04472217, 0.11103845, ..., 0.04581303,\n", " 0.02850476, 0.00554514],\n", " [-0.01037806, 0.00281054, -0.0485299 , ..., -0.03325456,\n", " -0.0058979 , 0.01733843]], dtype=float32)}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params_original[\"model\"][\"decoder\"][\"embed_positions\"]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray(-0.25866088, dtype=float32),\n", " DeviceArray(0.08769083, dtype=float32))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embedding_original = params_original[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n", "embedding_original.min(), embedding_new.max()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "assert jnp.allclose(embedding_new, embedding_original).item() == False" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "lm_head_original = params_original[\"lm_head\"][\"kernel\"]\n", "lm_head_reinit = params_reinit[\"lm_head\"][\"kernel\"]\n", "assert jnp.allclose(lm_head_reinit, lm_head_original).item() == 
False" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "assert jnp.allclose(\n", " params_reinit[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n", " \"FlaxBartAttention_0\"\n", " ][\"k_proj\"][\"kernel\"],\n", " params_original[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n", " \"FlaxBartAttention_0\"\n", " ][\"k_proj\"][\"kernel\"],\n", ").item()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save checkpoint for retrain" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we save the resulting model to retrain those layers." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "checkpoint_dir = \"mini-reinit\"" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "tcmalloc: large alloc 3367796736 bytes == 0x5633f235a000 @ 0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n" ] } ], "source": [ "model.save_pretrained(checkpoint_dir, params=params_reinit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload checkpoint to W&B" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mpcuenq\u001b[0m 
(\u001b[33mdalle-mini\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.12.21" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /home/pedro/code/dalle-mini/dalle-mini/tools/train/wandb/run-20220722_105625-2v9szi3q" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run astral-durian-2957 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wandb.init(\n", " entity=\"dalle-mini\",\n", " project=\"dalle-mini\",\n", " job_type=\"Seq2Seq\",\n", ")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "artifact = wandb.Artifact(\n", " name=f\"model-{wandb.run.id}\",\n", " type=\"DalleBart_model\",\n", " metadata={\"embeddings\": \"reset\"},\n", ")\n", "\n", "for filename in [\"config.json\", \"flax_model.msgpack\"]:\n", " artifact.add_file(f\"{Path(checkpoint_dir) / filename}\")" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wandb.run.log_artifact(artifact)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Waiting for W&B process to finish... (success)." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='1670.207 MB of 1670.207 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Synced astral-durian-2957: https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q
Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: ./wandb/run-20220722_105625-2v9szi3q/logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "include_colab_link": true, "machine_shape": "hm", "name": "DALL·E mini - Inference pipeline.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: tools/train/scalable_shampoo/README.md ================================================ # Notes Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax). Imports have been modified to be relative. This will eventually be replaced with `optax-shampoo` package. ================================================ FILE: tools/train/scalable_shampoo/distributed_shampoo.py ================================================ # coding=utf-8 # Copyright 2022 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # An implementation of distributed Shampoo optimizer from: # # Scalable Second Order Optimization for Deep Learning # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer # Preprint Paper: https://arxiv.org/abs/2002.09018 # # This implementation moves computation of inverse pth root back to the # accelerator (if higher precision is available). # # Authors: Rohan Anil (rohananil at google dot com) # Vineet Gupta (vineet at google dot com) # James Lottes (jlottes at google dot com) # Anudhyan Boral (anudhyan at google dot com) # """Distributed Shampoo Implementation.""" import enum import functools import itertools from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union import chex import jax import jax.numpy as jnp import numpy as np import optax from absl import logging from flax import struct from jax import lax from jax.experimental import pjit from jax.experimental.sparse import linalg from .quantization_utils import QuantizedValue from .symmetric_matrices import symmetric_matrices # Dtype for inverse-pth root routine # Switch to f64 if you have hardware that supports it. Enable the jax flag # jax_enable_x64 for this to work, otherwise it will default to float32. _MAT_INV_PTH_ROOT_DTYPE = jnp.float64 @struct.dataclass class TrainingMetrics: inverse_pth_root_errors: chex.Array # Error for inverse-pth roots. # TODO(rohananil): Add more important metrics to track during training. # Per parameter optimizer state used in data-parallel training. 
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained.

    Also rebuilt from the sharded representation by
    `_convert_to_parameter_stats`.
    """

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    statistics: List[Any]  # Statistics (QuantizedValue, chex.Array)
    preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).


# For training extremely large model; We keep a global state with a concatenated
# statistics and preconditioner states for all vars. This is so that we can
# annotate the leading axis to be sharded to save memory at the cost of
# communication.
@struct.dataclass
class GlobalShardedParameterStats:
    """Concatenated statistics/preconditioners of all variables (sharded mode).

    The leading axis indexes individual statistics; each entry is a square
    matrix padded to a common size (`_convert_to_parameter_stats` slices a
    parameter's rows by index and crops them back with `[:size, :size]`).
    """

    statistics: chex.Array  # Statistics
    preconditioners: chex.Array  # Preconditioners
    exponents: chex.Array  # exponents


# These are per-parameter local states; All statistics here mirror the parameter
# Thus the sharding is copied over from the param specification.
@struct.dataclass
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained.

    Sharded-mode counterpart of `ParameterStats`: the statistics and
    preconditioners themselves live in `GlobalShardedParameterStats`, and this
    record only stores where this parameter's entries start and how large
    each of them is.
    """

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
    index_start: np.int32 = struct.field(
        pytree_node=False
    )  # Index into global statistics array
    sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.


def init_training_metrics(num_statistics):
    """Returns zero-initialized `TrainingMetrics` with one error per statistic."""
    # Since the downstream apis expect a jnp.array - we create a dummy one if
    # num_statistics=0.
    if not num_statistics:
        return TrainingMetrics(jnp.array(0, jnp.float32))
    else:
        return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))


def init_training_metrics_shapes(num_statistics):
    """Like `init_training_metrics`, but holds `[shape, dtype]` pairs, not arrays."""
    # Since the downstream apis expect a jnp.array - we create a dummy one if
    # num_statistics=0.
    if not num_statistics:
        return TrainingMetrics([[], jnp.float32])
    else:
        return TrainingMetrics([[num_statistics], jnp.float32])


def init_training_metrics_pspec():
    """Partition spec for `TrainingMetrics`: an empty spec (no sharded axis)."""
    return TrainingMetrics(pjit.PartitionSpec())


class ShardedShampooStats(NamedTuple):
    """Shampoo state in sharded mode."""

    global_stats: Any
    local_stats: Any


class ShampooState(NamedTuple):
    """Optimizer state: `count` (presumably the step counter) and `stats`."""

    count: chex.Array
    stats: Any


class InitFnState(NamedTuple):
    """Bundle of callables; roles inferred from field names — verify at use site."""

    init_fn: Any
    pspec_fn: Any
    shape_and_dtype_fn: Any


class GraftingType(enum.IntEnum):
    """Base update used for layerwise grafting (selected via `graft_type`)."""

    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
    SQRT_N = 5
    ADAGRAD_NORMALIZED = 6


class PreconditionerType(enum.IntEnum):
    """Which tensor dimensions receive a Shampoo preconditioner."""

    # Default, computes preconditioner for each dim
    ALL = 1
    # One sided Shampoo, in this case only on input dims.
    # Assumes last dim is always the output dim and everything else input dim.
    INPUT = 2


def power_iteration(
    matrix,
    num_iters=100,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
):
    r"""Power iteration algorithm.

    The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
    a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
    of `A`, and a vector v, which is the corresponding eigenvector of `A`.

    References:
      [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

    Args:
      matrix: the symmetric PSD matrix.
      num_iters: Number of iterations.
      error_tolerance: Iterative exit condition. The loop stops once the
        eigenvalue estimate changes by no more than this between two steps.
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
        (best possible precision, slowest)

    Returns:
      eigen vector, eigen value
    """
    matrix_size = matrix.shape[-1]

    def _iter_condition(state):
        i, unused_v, unused_s, unused_s_v, run_step = state
        return jnp.logical_and(i < num_iters, run_step)

    def _iter_body(state):
        """One step of power iteration."""
        i, new_v, s, s_v, unused_run_step = state
        new_v = new_v / jnp.linalg.norm(new_v)

        s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
        # Rayleigh quotient of the normalized vector = current eigenvalue estimate.
        s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
        return (
            i + 1,
            s_v,
            s_new,
            s_v,
            jnp.greater(jnp.abs(s_new - s), error_tolerance),
        )

    # Figure out how to use step as seed for random.
    # A fixed seed (1729) makes the start vector, and hence the result,
    # deterministic across calls.
    v_0 = (
        np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
    )

    init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
    _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
    v_out = v_out / jnp.linalg.norm(v_out)
    return v_out, s_out


def mat_power(
    mat_m,
    p,
    precision=lax.Precision.HIGHEST,
):
    """A simple matrix power method. M^p where p can be TracedValue.

    Uses exponentiation by squaring inside a `lax.while_loop`, so `p` may be a
    traced (non-static) non-negative integer; `p == 0` yields the identity.
    """
    power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)

    def _iter_condition(state):
        i, _, _ = state
        return i > 0

    def _iter_body(state):
        i, power, mat = state

        # Multiply in the current square when the low bit of the exponent is set.
        power = jax.lax.cond(
            i % 2 == 1,
            lambda: jnp.matmul(mat, power, precision=precision),
            lambda: power,
        )
        i //= 2
        mat = jnp.matmul(mat, mat, precision=precision)
        return i, power, mat

    _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
    return result


def _pth_root_difference(w, alpha, beta, p):
    """Computes (w+alpha)^(-1/p)-(w+beta)^(-1/p)."""

    a = w + alpha
    b = w + beta
    a_minus_b = alpha - beta
    exp = -1 / p

    def _stable_subtract(b, a_minus_b):
        # Mathematically identical to the target expression, with (w+beta)^(-1/p)
        # term factored out and w cancellation in the subtraction.
        return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b))

    return jnp.where(
        # Choose the branch with the best log1p approximation.
        jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
        -_stable_subtract(a, -a_minus_b),
        _stable_subtract(b, a_minus_b),
    )


def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
    relative_matrix_epsilon=True,
    lobpcg_topk_precondition=0,
    lobpcg_max_iter=0,
):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses the Coupled newton iterations algorithm for
    the computation of a matrix's inverse pth root.


    References:
      [Functions of Matrices, Theory and Computation,
       Nicholas J Higham, Pg 184, Eq 7.18](
       https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

    Args:
      matrix: the symmetric PSD matrix whose power is to be computed. A
        non-square input is treated as the concatenated block form and is
        materialized first.
      p: exponent, for p a positive integer.
      num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
        It is multiplied by max(largest eigenvalue estimate, 1e-6), or by 1.0
        when neither LOBPCG nor relative_matrix_epsilon is used.
      error_tolerance: Error indicator, useful for early termination.
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
        (best possible precision, slowest)
      relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
        value when computing inverse-pth root.
      lobpcg_topk_precondition: If nonzero, specifies the number of top
        eigenvectors to subtract out before performing LOBPCG. Note this makes
        relative_matrix_epsilon essentially free.
      lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
        `lobpcg_topk_precondition`.

    Returns:
      matrix^(-1/p) and the error, where the error is the float32 value
      max |M - I| of the final coupled iterate M (0 for a 1x1 matrix).
    """

    # If the input is not square, materialize it from the concatenated form.
    if matrix.shape[0] != matrix.shape[1]:
        matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)

    assert matrix.shape[0] == matrix.shape[1]

    # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
    # Switch to f64 if you have hardware that supports it. Enable the jax flag
    # jax_enable_x64 for this to work.
    matrix_size = matrix.shape[0]
    orig_dtype = matrix.dtype
    matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
    alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
    identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
    # Kept (before any deflation) to measure the final error after re-deflation.
    original_matrix = matrix

    if lobpcg_topk_precondition > 0:
        # TODO(vladf): reuse previous top-k as the initial search directions
        pad_shape = (matrix_size - lobpcg_topk_precondition, lobpcg_topk_precondition)
        search_dirs = jnp.concatenate(
            (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0
        )
        eigvals, eigvecs, actual_iters = linalg.lobpcg_standard(
            matrix,
            search_dirs,
            lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter,
        )
        del actual_iters  # TODO(vladf): return diagnostics dictionary

        # The minimal eigenvalue among top-k becomes the maximal one in the whole
        # matrix after deflation.
        max_ev = jnp.min(eigvals)
        deflation = eigvals - max_ev
        scaled_vecs = eigvecs * jnp.sqrt(deflation)

        # Deflate out top eigenvectors to reduce matrix condition number.
        matrix -= scaled_vecs.dot(scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)

    # Only use power iteration if lobpcg wasn't already used to derive the
    # top eigenvalue.
    elif relative_matrix_epsilon:
        _, max_ev = power_iteration(
            matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
        )
        eigvals, eigvecs = None, None  # Unused but required by pytype.

    # Use absolute matrix epsilon scaling otherwise.
    else:
        max_ev = 1.0
        eigvals, eigvecs = None, None  # Unused but required by pytype.

    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
        error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
        return jnp.logical_and(i < num_iters, error_above_threshold)

    def _iter_body(state):
        # One coupled Newton step: mat_m is driven towards I while mat_h
        # accumulates the inverse p-th root.
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
        new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)

    if matrix_size == 1:
        # NOTE(review): this branch is not cast back to orig_dtype, unlike the
        # general branch below.
        resultant_mat_h = (matrix + ridge_epsilon) ** alpha
        error = jnp.array(0, jnp.float32)
    else:
        damped_matrix = matrix + ridge_epsilon * identity

        # Scaled initial iterate (M_0 = z * A, H_0 = z^(1/p) * I).
        z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
        new_mat_m_0 = damped_matrix * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state
        )
        error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
        # If the last step blew the error up by more than 1.2x, fall back to the
        # previous iterate instead of the latest one.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)

    if lobpcg_topk_precondition > 0:
        # Since we deflated the top eigenvectors prior to p-th root inverse,
        # the resultant matrix has larger eigenvalues associated with those
        # same eigenvectors, which we need to now re-deflate.
        #
        # Note that _pth_root_difference returns positive values for this
        # particular argument ordering as min(eigvals) <= eigvals for the
        # jnp.sqrt below.
        pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p)
        scaled_vecs = eigvecs * jnp.sqrt(pth_diff)
        resultant_mat_h = (
            resultant_mat_h.astype(scaled_vecs.dtype)
            - scaled_vecs.dot(scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
        ).astype(orig_dtype)
        # Recompute the error against the undeflated input matrix.
        mat_m = jnp.matmul(
            mat_power(resultant_mat_h, p),
            original_matrix,
            precision=jax.lax.Precision.HIGHEST,
        )
        error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)

    return resultant_mat_h, error


def merge_small_dims(shape_to_merge, max_dim):
    """Merge small dimensions.

    If there are some small dimensions, we collapse them:
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048]

    Consecutive dims are multiplied together greedily while the running product
    stays <= max_dim; size-1 dims are dropped. An all-ones shape becomes [1] and
    an empty shape stays [].

    Args:
      shape_to_merge: Shape to merge small dimensions.
      max_dim: Maximal dimension of output shape used in merging.

    Returns:
      Merged shape.
    """
    if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
        return [1]

    resulting_shape = []
    product = 1
    for d in shape_to_merge:
        if product * d <= max_dim:
            product *= d
        else:
            if product > 1:
                resulting_shape.append(product)
            product = d
    if product > 1:
        resulting_shape.append(product)
    return resulting_shape


def pad_square_matrix(mat, max_size):
    """Pad a square matrix up to max_size.

    Args:
      mat: a matrix to pad.
      max_size: matrix size requested.

    Returns:
      Given M returns [[M, 0], [0, I]]

    Raises:
      ValueError: if `mat` is not square or is larger than `max_size`.
    """
    rows, cols = mat.shape
    if rows != cols:
        raise ValueError(
            f"Must have rows == cols, instead got rows={rows}, cols={cols}"
        )
    if cols > max_size:
        raise ValueError(
            f"Must have cols <= max_size. Instead got cols={cols}, max_size={max_size}."
        )
    if rows == max_size:
        return mat
    pad_size = max_size - rows

    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
    eye = jnp.eye(pad_size, dtype=mat.dtype)
    mat = jnp.concatenate([mat, zs1], 1)
    mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
    return mat


def make_sliced_padding(
    symmetric_block_size,
    num_blocks,
    starting_block,
    dtype,
):
    """Returns padding for symmetric block matrix.

    Specifically, the padding is given as concatenated rectangular matrices
    representing the lower-triangular rows below the starting block. For example,
    if we want to pad the symmetric matrix

    M = [[A, B^T]
         [B, C]],

    the desired output (in terms of the full matrix) with num_blocks = 4 is

    M_padded = [[A, B^T, 0, 0]
                [B, C,   0, 0]
                [0, 0,   I, 0]
                [0, 0,   0, I]].

    We would represent M as the block matrix mat = [A, B, C]. In this form, the
    additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
    triangular parts in the third and fourth rows).

    Args:
      symmetric_block_size: The size of each block.
      num_blocks: The total number of block rows.
      starting_block: The block row where to start the padding.
      dtype: The type to use for the blocks.

    Returns:
      An array with `symmetric_block_size` rows; row-block i (for i from
      starting_block to num_blocks - 1) contributes i zero blocks followed by
      one identity block. It has zero columns when starting_block == num_blocks.
    """
    if starting_block == num_blocks:
        return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)

    blocks = []
    for i in range(starting_block, num_blocks):
        blocks.append(
            jnp.zeros(
                shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
            )
        )
        blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
    return jnp.concatenate(blocks, axis=-1)


def pad_block_symmetric_matrix(
    mat,
    symmetric_block_size,
    max_num_blocks,
):
    """Returns the padded blocked symmetric matrix.

    The size of the padded matrix will be:
      [symmetric_block_size, symmetric_block_size * max_num_blocks]

    The input matrix can either:
      - Be square with size less or equal to symmetric_block_size. In this case,
        mat will first be padded to a square matrix of size symmetric_block_size,
        and then be padded again up to the full size of the blocked matrix.
      - Be a rectangle with number of rows equal to block size.
        In this case, number of columns must be a multiple of number of rows, and
        the ratio must correspond to a block representation of a symmetric matrix.
        That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
        number of block rows represented by the matrix.

    Args:
      mat: The input block matrix.
      symmetric_block_size: The size of blocks.
      max_num_blocks: The largest number of blocks to pad to.

    Raises:
      ValueError: if `mat` has too many rows, fewer columns than rows, or more
        columns than `symmetric_block_size * max_num_blocks`.
    """
    rows, cols = mat.shape
    if rows > symmetric_block_size:
        raise ValueError(
            "Must have rows <= symmetric_block_size. Instead got "
            f"rows={rows}, symmetric_block_size={symmetric_block_size}."
        )
    if rows > cols:
        raise ValueError(
            f"Must have rows <= cols, instead got rows={rows}, cols={cols}."
        )
    if cols > symmetric_block_size * max_num_blocks:
        raise ValueError(
            "Must have cols <= symmetric_block_size * max_num_blocks "
            f"Instead got cols={cols}, "
            f"symmetric_block_size={symmetric_block_size}, "
            f"max_num_blocks={max_num_blocks}."
        )
    if rows < symmetric_block_size:
        mat = pad_square_matrix(mat, max_size=symmetric_block_size)
    # Update rows and cols after possibly padding in pad_square_matrix.
    rows, cols = mat.shape
    assert rows == symmetric_block_size
    assert cols % rows == 0
    filled_blocks = cols // rows
    padding_blocks = make_sliced_padding(
        symmetric_block_size=symmetric_block_size,
        num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
        starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
        dtype=mat.dtype,
    )
    return jnp.concatenate([mat, padding_blocks], axis=-1)


def pad_vector(vec, max_size):
    """Pad a vector to a max_size.

    Args:
      vec: a vector to pad.
      max_size: matrix size requested.

    Returns:
      Given V returns [V, 0]
    """
    size = vec.shape[0]
    assert size <= max_size
    if size == max_size:
        return vec
    pad_size = max_size - size
    zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
    return jnp.concatenate([vec, zs1], 0)


def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
    """Avoids wasteful buffer allocation with XLA.

    Acts as a conditional built on `lax.while_loop`. The carry starts as
    `(predicate, *init_state)`. The body runs `compute_fn(*args, **kwargs)` and
    clears the flag, so it executes at most once.

    Returns the tuple produced by `compute_fn` when `predicate` is true, and
    `init_state` as a tuple otherwise. `init_state` must be a list, because it
    is concatenated to `[predicate]`.
    """

    def _iter_body(unused_state):
        results = compute_fn(*args, **kwargs)
        return tuple([False] + list(results))

    def _iter_condition(state):
        return state[0]

    results = jax.lax.while_loop(
        _iter_condition, _iter_body, tuple([predicate] + init_state)
    )
    return tuple(results[1:])


class BlockPartitioner:
    """Partitions a tensor into smaller tensors.

    When `block_size > 0`, every axis longer than `block_size` is split into
    chunks of `block_size`, and the last chunk holds the remainder.
    """

    def __init__(self, param, block_size):
        self._shape = param.shape
        self._splits = []
        split_sizes = []
        # We split params into smaller blocks. Here we store the metadata to make
        # that split.
        for i, d in enumerate(param.shape):
            if 0 < block_size < d:
                # d-1, otherwise split appends a 0-size array.
                nsplit = (d - 1) // block_size
                indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
                sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
                sizes[-1] = d - indices[-1]
                self._splits.append((i, indices))
                split_sizes.append(sizes)
            else:
                split_sizes.append(np.array([d], dtype=np.int32))
        self._split_sizes = split_sizes

    def split_sizes(self):
        """Per-axis arrays of chunk sizes (a single entry for unsplit axes)."""
        return self._split_sizes

    def partition(self, tensor):
        """Partition tensor into blocks."""

        assert tensor.shape == self._shape
        tensors = [tensor]
        for i, indices in self._splits:
            tensors_local = []
            for t in tensors:
                tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
            tensors = tensors_local
        return tensors

    def merge_partitions(self, partitions):
        """Merge partitions back to original shape."""

        # Undo the splits in reverse order: each pass concatenates consecutive
        # groups of n pieces along the axis that was split last.
        for i, indices in reversed(self._splits):
            n = len(indices) + 1
            partial_merged_tensors = []
            ind = 0
            while ind < len(partitions):
                partial_merged_tensors.append(
                    jnp.concatenate(partitions[ind : ind + n], axis=i)
                )
                ind += n
            partitions = partial_merged_tensors
        assert len(partitions) == 1
        return partitions[0]


def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None):
    """Updated statistics via weighted average with new Gram matrix.

    Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose
    columns are the flattened slices of the tensor `g` along the given `axis`.
    (So, `old_stats` and the returned matrix have dimensions n x n where
    n = `g.shape[axis]`).

    Args:
      old_stats:  Old statistics.
      g:  Gradient tensor.
      axis:  Axis along which to slice `g`.
      w1:  Scalar weight for old statistics.
      w2:  Scalar weight for new Gram matrix.
      precision: Optional precision XLA related flag, the available options are:
        a) lax.Precision.DEFAULT (better step time, but not precise)
        b) lax.Precision.HIGH (increased precision, slower)
        c) lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
      Weighted average of old and new statistics.
    """
    # Contract over every axis except `axis`, leaving an n x n matrix.
    axes = [i for i in range(g.ndim) if i != axis]
    gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision)
    return w1 * old_stats + w2 * gram_matrix


class Preconditioner:
    """Compute statistics/shape from gradients for preconditioning."""

    def __init__(
        self,
        param,
        block_size,
        merge_small_dims_block_size,
        best_effort_shape_interpretation,
        preconditioner_type=PreconditionerType.ALL,
    ):
        """Initializes the preconditioner.

        Args:
          param: parameter to precondition.
          block_size: Block size used to split param.
          merge_small_dims_block_size: Block size for merging dims.
          best_effort_shape_interpretation: Whether to collapse/merge dims together.
          preconditioner_type: Type of preconditioner to use.
        """
        self._original_shape = param.shape
        self._transformed_shape = param.shape
        if best_effort_shape_interpretation:
            self._transformed_shape = merge_small_dims(
                self._original_shape, merge_small_dims_block_size
            )
        reshaped_param = jnp.reshape(param, self._transformed_shape)
        self._partitioner = BlockPartitioner(reshaped_param, block_size)
        self._preconditioner_type = preconditioner_type

    def updated_statistics_from_grad(
        self,
        stats,
        grad,
        w1,
        w2,
        to_float=None,
        from_float=None,
        precision=None,
    ):
        """Update statistics from gradients.

        Args:
          stats: Old statistics, a flat list with one entry per preconditioned
            axis of each block partition (in partition order).
          grad: Gradient to compute statistics from.
          w1: Weight for old statistics.
          w2: Weight for new statistics.
          to_float: Optional function for converting stats to floating point.
          from_float: Optional function for converting from floating point.
          precision: Optional precision XLA related flag, the available options are:
            a) lax.Precision.DEFAULT (better step time, but not precise)
            b) lax.Precision.HIGH (increased precision, slower)
            c) lax.Precision.HIGHEST (best possible precision, slowest)

        Returns:
          A list of updated gradient statistics for each partition.
        """
        to_float = to_float if to_float is not None else (lambda x: x)
        from_float = from_float if from_float is not None else (lambda x: x)
        update = functools.partial(gram_weighted_update, precision=precision)
        reshaped_grad = jnp.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        new_stats = []
        index = 0
        for g in partitioned_grads:
            should_preconditioned_dims = self.should_precondition_dims()
            num_preconditioners = sum(should_preconditioned_dims)
            # Preconditioned dims are always the leading ones (see
            # should_precondition_dims), so iterating range(...) is correct.
            for axis in range(num_preconditioners):
                new_stat = update(to_float(stats[index]), g, axis, w1, w2)
                new_stats.append(from_float(new_stat))
                index += 1
        return new_stats

    def should_precondition_dims(self):
        """A vector containing indicator indicating if the dim is preconditioned."""
        split_sizes = self._partitioner.split_sizes()
        rank = len(split_sizes)
        if self._preconditioner_type == PreconditionerType.ALL or rank <= 1:
            return [True] * rank
        else:
            return [True] * (rank - 1) + [False]

    def shapes_for_preconditioners(self):
        """Returns shape from statistics."""
        split_sizes = self._partitioner.split_sizes()
        rank = len(split_sizes)
        # We ignore preconditioner types if rank == 1
        preconditioner_shapes = []
        for t in itertools.product(*split_sizes):
            if self._preconditioner_type == PreconditionerType.ALL or rank <= 1:
                preconditioner_shapes.extend([[d, d] for d in t])
            else:
                preconditioner_shapes.extend([[d, d] for d in t[:-1]])
        return preconditioner_shapes

    def exponent_for_preconditioner(self):
        """Returns exponent to use for inverse-pth root M^{-1/p}."""
        should_preconditioned_dims = self.should_precondition_dims()
        num_preconditioners = sum(should_preconditioned_dims)
        return 2 * num_preconditioners

    def preconditioned_grad(self, grad, preconditioners):
        """Precondition the gradient.

        Args:
          grad: A gradient tensor to precondition.
          preconditioners: A list of preconditioners to apply.

        Returns:
          A preconditioned gradient.
        """

        reshaped_grad = jnp.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        preconditioned_partitioned_grads = []
        for i, g in enumerate(partitioned_grads):
            should_preconditioned_dims = self.should_precondition_dims()
            num_preconditioners = sum(should_preconditioned_dims)
            preconditioners_for_grad = preconditioners[
                i * num_preconditioners : (i + 1) * num_preconditioners
            ]
            precond_g = g
            rank = len(g.shape)
            # Each step consumes axis 0 and appends the result as the last axis,
            # so after `rank` steps the axes are back in their original order.
            for j, precondition in enumerate(should_preconditioned_dims):
                if precondition:
                    precond_g = jnp.tensordot(
                        precond_g, preconditioners_for_grad[j], axes=[[0], [0]]
                    )
                else:
                    precond_g = jnp.transpose(precond_g, axes=(*range(1, rank), 0))
            preconditioned_partitioned_grads.append(precond_g)
        merged_grad = self._partitioner.merge_partitions(
            preconditioned_partitioned_grads
        )
        return jnp.reshape(merged_grad, self._original_shape)


def _convert_to_parameter_stats(global_stats, local_stat, convert_statistics=True):
    """Creates parameter stats from sharded stats.

    Slices this parameter's rows out of the global arrays and crops each padded
    matrix back to its true size. `statistics` is None when
    `convert_statistics` is False.
    """
    index_start = int(local_stat.index_start)
    index_end = int(len(local_stat.sizes)) + index_start
    statistics = global_stats.statistics[index_start:index_end, :, :]
    preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
    new_statistics = []
    new_preconditioners = []
    for i, size in enumerate(local_stat.sizes):
        new_statistics.append(statistics[i][:size, :size])
        new_preconditioners.append(preconditioners[i][:size, :size])
    if not convert_statistics:
        new_statistics = None
    return ParameterStats(
        local_stat.diagonal_statistics,
        new_statistics,
        new_preconditioners,
        local_stat.diagonal_momentum,
        local_stat.momentum,
        local_stat.training_metrics,
    )


def _convert_from_parameter_stats(parameter_stats, local_stats):
    """Creates sharded stats from parameter stats.

    Statistics and preconditioners are dropped; the index metadata is copied
    from `local_stats`.
    """
    return LocalShardedParameterStats(
        parameter_stats.diagonal_statistics,
        parameter_stats.diagonal_momentum,
        parameter_stats.momentum,
        parameter_stats.training_metrics,
        local_stats.index_start,
        local_stats.sizes,
    )


def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
    """Adds errors back into local statistics.

    A new error is kept only if it is > 0 and != inverse_failure_threshold;
    otherwise the previously stored error is retained. (Presumably the
    threshold value marks a failed inverse — confirm with the caller.)
    """
    new_local_stats = []
    for local_stat in local_stats:
        if local_stat.sizes:
            index_start = int(local_stat.index_start)
            index_end = int(len(local_stat.sizes)) + index_start
            per_stat_error = errors[index_start:index_end]
        else:
            per_stat_error = jnp.array(0, jnp.float32)
        if local_stat.sizes:
            per_stat_error = jnp.where(
                jnp.logical_and(
                    per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
                ),
                per_stat_error,
                local_stat.training_metrics.inverse_pth_root_errors,
            )
        new_local_stats.append(
            LocalShardedParameterStats(
                local_stat.diagonal_statistics,
                local_stat.diagonal_momentum,
                local_stat.momentum,
                TrainingMetrics(per_stat_error),
                local_stat.index_start,
                local_stat.sizes,
            )
        )
    return new_local_stats


def batch(x, num_devices):
    """Batch `x` so that leading axis is num_devices.

    NOTE(review): assumes len(x) is a multiple of num_devices. Otherwise the
    groups have unequal lengths and the outer jnp.stack fails. If
    len(x) < num_devices, the group size is 0 and `range` raises.
    """
    n = len(x)
    b = int(n / num_devices)
    return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])


def unbatch(batched_values):
    """Unbatch values across leading axis and return a list of elements.

    NOTE(review): jnp.squeeze drops every size-1 axis, so elements whose own
    shape contains a 1 are squeezed as well.
    """
    b1, b2 = batched_values.shape[0], batched_values.shape[1]
    results = []
    for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
        v_array = jnp.squeeze(v_array)
        # b2 = batches (number of preconditioner computation) per core.
        if b2 > 1:
            for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
                results.append(jnp.squeeze(v))
        else:
            results.append(v_array)

    return results


def distributed_shampoo(
    learning_rate,
    block_size,
    beta1=0.9,
    beta2=0.999,
    diagonal_epsilon=1e-10,
    matrix_epsilon=1e-6,
    weight_decay=0.0,
    start_preconditioning_step=5,
    preconditioning_compute_steps=1,
    statistics_compute_steps=1,
    best_effort_shape_interpretation=True,
    graft_type=GraftingType.SGD,
    nesterov=True,
    exponent_override=0,
    # Pass pmap 'batch axis name' in pmap mode.
    batch_axis_name=None,
    ### Only set following 3 params in pjit/spmd mode.
### WARNING: Experimental
    statistics_partition_spec=None,
    preconditioner_partition_spec=None,
    num_devices_for_pjit=None,
    shard_optimizer_states=False,
    ###
    ### Experimental memory reduction mode
    best_effort_memory_usage_reduction=False,
    ###
    inverse_failure_threshold=0.1,
    moving_average_for_momentum=False,
    skip_preconditioning_dim_size_gt=4096,
    clip_by_scaled_gradient_norm=None,
    precision=lax.Precision.HIGHEST,
    tensordot_precision=None,
    relative_matrix_epsilon=True,
    merge_small_dims_block_size=4096,
    lobpcg_topk_precondition=0,
    lobpcg_max_iter=0,
    precondtioner_type=PreconditionerType.ALL,
    skip_preconditioning_rank_lt=1,
    decoupled_learning_rate=True,
    decoupled_weight_decay=False,
):
    """Distributed Shampoo optimizer.

    Distributed Shampoo is a second-order preconditioned method (concretely, a
    variant of full-matrix Adagrad). It provides significant convergence and
    wall-clock time improvements compared to conventional first-order methods,
    and has been shown to scale to large state-of-the-art deep learning models.

    References:
      Scalable Second Order Optimization for Deep Learning,
      Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer

      Preprint: https://arxiv.org/abs/2002.09018

    Args:
      learning_rate: the step size used to update the parameters.
      block_size: Block size for large layers (if > 0). The preconditioning
        computation is cubic in the dimension of the tensor. The block size lets
        us chunk layers into sub-layers whose maximal dimension is this value.
        Use 128 as a default (increase it if you have compute budget).
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
        to AdaGrad is enabled).
      matrix_epsilon: epsilon to add to statistics before computing the inverse
        pth root. If you run the inverse pth root in f32 precision (recommended
        today) this can go up to 1e-6. If you have recent hardware with native
        f64 precision, set this up to 1e-12.
      weight_decay: Weight decay for regularization.
      start_preconditioning_step: Step at which to start Shampoo updates; before
        it, the diagonal update is used. This is because there is not yet enough
        information for a stable inverse.
      preconditioning_compute_steps: How often to compute the preconditioner.
        This and statistics_compute_steps are performance-tuning params for
        controlling memory and compute requirements. Ideally set both to 1.
      statistics_compute_steps: How often to compute statistics.
      best_effort_shape_interpretation: If there are some small dimensions,
        collapse them, e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12]
        if block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048].
      graft_type: Grafting is a technique to fix the layerwise scale of the
        Shampoo optimizer. It lets us plug the Shampoo optimizer into settings
        where SGD/AdaGrad is already well tuned.
      nesterov: Nesterov momentum.
      exponent_override: Override the exponent used in the matrix inverse. 0
        means use the default of 2 * (number of preconditioned dims).
      batch_axis_name: Name of the pmap axis used for data-parallel training
        (pmap mode only).
      statistics_partition_spec: PartitionSpec to be used in sharded mode.
      preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
      num_devices_for_pjit: Number of devices to parallelize over when using
        pjit.
      shard_optimizer_states: Shard optimizer states to save memory in model
        parallel training.
      best_effort_memory_usage_reduction: Best-effort memory usage reduction:
        - diagonal_statistics: still stored as jnp.float32
          (see `_quantize_diagonal_statistics`).
        - momentum buffers (2x): jnp.int8, for parameters of rank > 1 only.
        - statistics, preconditioners: jnp.int16 plus extracted diagonals,
          only in pmap mode (when `batch_axis_name` is set).
      inverse_failure_threshold: Numerics are hard and inverses sometimes fail;
        we detect failure using this threshold. A newly computed preconditioner
        is discarded, and the previous one kept, when its error is NaN or
        >= this threshold.
      moving_average_for_momentum: Whether to use moving average for momentum
        instead of exponential moving average.
      skip_preconditioning_dim_size_gt: Skip preconditioning if any dimension
        of the parameter is greater than this value.
      clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
        when using RMSProp Grafting).
      precision: XLA precision flag. Available options are:
        a) lax.Precision.DEFAULT (better step time, but not precise)
        b) lax.Precision.HIGH (increased precision, slower)
        c) lax.Precision.HIGHEST (best possible precision, slowest)
      tensordot_precision: Optional precision to use for the tensordot operation
        when computing statistics (e.g., G Gᵀ). Same options as `precision`
        above.
      relative_matrix_epsilon: Whether to use an epsilon relative to the max
        eigenvalue when computing the inverse pth root.
      merge_small_dims_block_size: Used as the maximum block size to merge the
        shapes.
      lobpcg_topk_precondition: If nonzero, specifies the number of top
        eigenvectors to subtract out before performing LOBPCG. Note this makes
        relative_matrix_epsilon essentially free.
      lobpcg_max_iter: Number of LOBPCG iterations; if zero, defaults to
        `lobpcg_topk_precondition`.
      precondtioner_type: (sic) Preconditioner type: select all preconditioners
        or left-only preconditioners (the last dim is not preconditioned).
      skip_preconditioning_rank_lt: Skips preconditioning for parameters with
        rank less than this value.
      decoupled_learning_rate: If True, use a decoupled learning rate; otherwise
        couple it with the preconditioned gradient computation. (Default True)
      decoupled_weight_decay: If True, use decoupled weight decay; otherwise
        couple weight decay with the gradient. (Default False)

    Returns:
      a GradientTransformation.
    """

    def _graft_type_has_diagonal_statistics():
        """Returns True if a diagonal first-order method is used for grafting."""
        return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N

    def quantized_dtype_for_momentum_buffers(var):
        # Momentum is quantized to int8 only in memory-reduction mode, and only
        # for parameters of rank > 1; all other momentum stays float32.
        return (
            jnp.int8
            if best_effort_memory_usage_reduction and len(var.shape) > 1
            else jnp.float32
        )

    # Preconditioner and statistics are both stored as int16 in this mode.
    # We take out the diagonal to make quantization easier.
def quantized_dtype_for_second_moment_statistics_buffers(): return ( jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32 ) # Preconditioner and statistics are both stores as int16 in this mode. # We take out the diagonal to make quantization easier. def quantized_dtype_for_second_moment_preconditioner_buffers(): return ( jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32 ) def _to_float(maybe_quantized): if isinstance(maybe_quantized, QuantizedValue): return maybe_quantized.to_float() else: return maybe_quantized def _maybe_quantize_statistics(statistics_list): return _maybe_quantize_matrices_with_dtype( statistics_list, quantized_dtype_for_second_moment_statistics_buffers() ) def _maybe_quantize_preconditioners(statistics_list): return _maybe_quantize_matrices_with_dtype( statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers() ) def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype): if quantized_dtype != jnp.float32: return [ QuantizedValue.from_float_value( s, quantized_dtype, extract_diagonal=True ) for s in statistics_list ] else: return statistics_list def _maybe_dequantize_preconditioners(preconditioner_list): return _maybe_dequantize_matrices_with_dtype( preconditioner_list, quantized_dtype_for_second_moment_preconditioner_buffers(), ) def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype): if quantized_dtype != jnp.float32: return [s.to_float() for s in statistics_list] else: return statistics_list def _quantize_diagonal_statistics(diagonal_statistics): return QuantizedValue.from_float_value(diagonal_statistics, jnp.float32) def _quantize_momentum(momentum_statistics): return QuantizedValue.from_float_value( momentum_statistics, quantized_dtype_for_momentum_buffers(momentum_statistics), ) def preconditioner_from_params(param): """Returns a Preconditioner object for given param.""" return Preconditioner( param, block_size, 
merge_small_dims_block_size, best_effort_shape_interpretation, precondtioner_type, ) def sharded_init_fn(params): """Returns optimizer state (for PJIT mode). Args: params: the parameters that should be updated. """ params_flat, treedef = jax.tree_flatten(params) # Find max size to pad to. max_size = 0 for param in params_flat: preconditioner = preconditioner_from_params(param) if not _skip_preconditioning(param): shapes = preconditioner.shapes_for_preconditioners() sizes = [s[0] for s in shapes] max_size = max(max(sizes), max_size) padded_statistics = [] padded_preconditioners = [] local_stats_flat = [] exponents = [] for param in params_flat: preconditioner = preconditioner_from_params(param) shapes = preconditioner.shapes_for_preconditioners() sizes = [] statistics = [] preconditioners = [] index_start = len(padded_statistics) if not _skip_preconditioning(param): sizes = [s[0] for s in shapes] shapes = preconditioner.shapes_for_preconditioners() statistics = [ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) for s in shapes ] preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes] padded_statistics.extend(statistics) padded_preconditioners.extend(preconditioners) exponent = ( preconditioner.exponent_for_preconditioner() if exponent_override == 0 else exponent_override ) exponents.extend([exponent] * len(shapes)) diagonal_statistics = _quantize_diagonal_statistics(jnp.zeros_like(param)) diagonal_momentum = _quantize_momentum(jnp.zeros_like(param)) momentum = _quantize_momentum(jnp.zeros_like(param)) local_stats_flat.append( LocalShardedParameterStats( diagonal_statistics, diagonal_momentum, momentum, init_training_metrics(len(sizes)), index_start, sizes, ) ) local_stats = jax.tree_unflatten(treedef, local_stats_flat) to_pad = -len(padded_statistics) % num_devices_for_pjit if max_size == 0: to_pad = num_devices_for_pjit max_size = block_size stat_dtype = jnp.float32 else: stat_dtype = padded_statistics[0].dtype # Pad the statistics and 
preconditioner matrices to be a multiple of # num devices. # TODO(rohananil): Relax to only the size of the mesh axis where the dim # is split on. padded_statistics.extend( [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] ) padded_preconditioners.extend( [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] ) exponents.extend([1 for _ in range(to_pad)]) global_stats = GlobalShardedParameterStats( jnp.stack(padded_statistics), jnp.stack(padded_preconditioners), jnp.stack(exponents), ) return ShampooState( count=jnp.zeros([], jnp.int32), stats=ShardedShampooStats(global_stats, local_stats), ) def _max_statistics_size_from_params(params): max_size = 0 for param in params: param_clone = jnp.zeros(param.shape, dtype=param.dtype) preconditioner = preconditioner_from_params(param_clone) if not _skip_preconditioning(param): shapes = preconditioner.shapes_for_preconditioners() sizes = [s[0] for s in shapes] max_size = max(max(sizes), max_size) return max_size def _remove_leading_sharding_annotation(pspec): """Mapping from N-d to (N-1)-d, used for quantization, factoring etc.""" # None and PSpec(None) are valid PSpecs. if pspec and len(pspec) > 1: return pjit.PartitionSpec(*pspec[1:]) else: return [] def sharded_init_partition_spec_fn( params, params_partition_spec, partition_spec_for_statistics ): """Returns a parallel state tree with PartitionSpec associated with state. Args: params: A pytree with params. params_partition_spec: A pytree with PartitionSpec for params. partition_spec_for_statistics: PartitionSpec for the statistics. """ # Parallel lists of spec, and params. param_pspec_flat, _ = jax.tree_flatten( params_partition_spec, is_leaf=lambda x: x is None ) params_flat, treedef = jax.tree_flatten(params) assert param_pspec_flat assert params_flat # Step is replicated across cores. # None means cores. 
local_stats_flat = [] num_statistics = 0 for param, param_pspec in zip(params_flat, param_pspec_flat): param_clone = jnp.zeros(param.shape, dtype=param.dtype) preconditioner = preconditioner_from_params(param_clone) shapes = preconditioner.shapes_for_preconditioners() sizes = [] index_start = num_statistics if not _skip_preconditioning(param): sizes = [s[0] for s in shapes] shapes = preconditioner.shapes_for_preconditioners() num_statistics += len(shapes) qdtype = quantized_dtype_for_momentum_buffers(param) m1_pspec = param_pspec m2_pspec = param_pspec m1_scale_pspec = [] m2_scale_pspec = [] if qdtype != jnp.float32: m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec) m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec) local_stats_flat.append( LocalShardedParameterStats( QuantizedValue( param_pspec, [], [], jnp.float32, False, list(param.shape) ), QuantizedValue( m1_pspec, [], m1_scale_pspec, qdtype, False, list(param.shape) ), QuantizedValue( m2_pspec, [], m2_scale_pspec, qdtype, False, list(param.shape) ), init_training_metrics_pspec(), index_start, sizes, ) ) local_stats = jax.tree_unflatten(treedef, local_stats_flat) global_stats = GlobalShardedParameterStats( partition_spec_for_statistics, partition_spec_for_statistics, pjit.PartitionSpec(), ) count_pspec = pjit.PartitionSpec() return ShampooState( count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats) ) def sharded_init_shape_and_dtype_fn(params): """Returns a parallel state tree with shape, dtype associated with state. Args: params: A pytree with params. """ # Parallel lists of spec, and params. params_flat, treedef = jax.tree_flatten(params) assert params_flat # Step is replicated across cores. # None means cores. 
local_stats_flat = [] num_statistics = 0 for param in params_flat: param_clone = jnp.zeros(param.shape, dtype=param.dtype) preconditioner = preconditioner_from_params(param_clone) shapes = preconditioner.shapes_for_preconditioners() sizes = [] index_start = num_statistics if not _skip_preconditioning(param): sizes = [s[0] for s in shapes] shapes = preconditioner.shapes_for_preconditioners() num_statistics += len(shapes) qdtype = quantized_dtype_for_momentum_buffers(param) m1_shape_and_dtype = [list(param.shape), param.dtype] m2_shape_and_dtype = [list(param.shape), param.dtype] m1_scale_shape_and_dtype = [] m2_scale_shape_and_dtype = [] if qdtype != jnp.float32: m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype] m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype] diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype] local_stats_flat.append( LocalShardedParameterStats( QuantizedValue( diagonal_statistics_shape_and_dtype, [], [], jnp.float32, False, list(param.shape), ), QuantizedValue( m1_shape_and_dtype, [], m1_scale_shape_and_dtype, qdtype, False, list(param.shape), ), QuantizedValue( m2_shape_and_dtype, [], m2_scale_shape_and_dtype, qdtype, False, list(param.shape), ), init_training_metrics_shapes(len(sizes)), index_start, sizes, ) ) local_stats = jax.tree_unflatten(treedef, local_stats_flat) max_statistics_size = _max_statistics_size_from_params(params_flat) to_pad = -num_statistics % num_devices_for_pjit num_statistics += to_pad if num_statistics == 0: num_statistics = num_devices_for_pjit max_statistics_size = block_size statistics_shape = [num_statistics, max_statistics_size, max_statistics_size] global_stats = GlobalShardedParameterStats( [statistics_shape, jnp.float32], [statistics_shape, jnp.float32], [[num_statistics], jnp.int32], ) return ShampooState( count=[[], jnp.float32], stats=ShardedShampooStats(global_stats, local_stats), ) def sharded_update_fn(grads, state, params): """Transform the input gradient and update all 
statistics in sharded mode. Args: grads: the gradient tensors for the parameters. state: a named tuple containing the state of the optimizer params: the parameters that should be updated. Returns: A tuple containing the new parameters and the new optimizer state. """ params_flat, treedef = jax.tree_flatten(params) grads_flat = treedef.flatten_up_to(grads) global_stats = state.stats.global_stats local_stats_flat = treedef.flatten_up_to(state.stats.local_stats) stats_flat = [ _convert_to_parameter_stats(global_stats, local_stat) for local_stat in local_stats_flat ] new_stats_flat = jax.tree_map( lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat, ) outputs = jax.tree_map( lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat, ) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) updates = jax.tree_unflatten(treedef, updates_flat) # Create new local_stats new_local_stats_flat = [ _convert_from_parameter_stats(new_stat, local_stat) for new_stat, local_stat in zip(new_stats_flat, local_stats_flat) ] max_size = global_stats.statistics.shape[1] new_padded_statistics = [] for stat in new_stats_flat: new_padded_statistics.extend( [pad_square_matrix(stat, max_size) for stat in stat.statistics] ) # Create global stats # TODO(rohananil): Preconditioner is not updated every step, so cost of # stack/pad can be obviated away. # Pad the statistics and preconditioner matrices to be a multiple of # num devices. # TODO(rohananil): Relax to only the size of the mesh axis where the dim # is split on. 
to_pad = -len(new_padded_statistics) % num_devices_for_pjit if not new_padded_statistics: to_pad = num_devices_for_pjit stat_dtype = jnp.float32 else: stat_dtype = new_padded_statistics[0].dtype new_padded_statistics.extend( [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)] ) new_stacked_padded_statistics = jnp.stack(new_padded_statistics) new_stacked_padded_statistics = pjit.with_sharding_constraint( new_stacked_padded_statistics, statistics_partition_spec ) def _internal_inverse_pth_root_all(): preconditioners, errors = _matrix_inverse_pth_root_pjit( new_stacked_padded_statistics, global_stats.exponents, statistics_partition_spec, ) return preconditioners, errors if preconditioning_compute_steps == 1: new_preconditioners, errors = _internal_inverse_pth_root_all() else: # Passing statistics instead of preconditioners as they are similarly # shaped tensors. Note statistics will be ignored as we are passing in # a large init value for error. preconditioners_init = new_stacked_padded_statistics n = new_stacked_padded_statistics.shape[0] errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold init_state = [preconditioners_init, errors_init] perform_step = state.count % preconditioning_compute_steps == 0 new_preconditioners, errors = efficient_cond( perform_step, _internal_inverse_pth_root_all, init_state ) new_local_stats_flat = _add_error_into_local_stats( new_local_stats_flat, errors, inverse_failure_threshold ) new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat) errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( jnp.isnan(errors), errors >= inverse_failure_threshold ).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. 
new_conditional_preconditioners = ( predicate * global_stats.preconditioners + (1.0 - predicate) * new_preconditioners ) new_global_stats = GlobalShardedParameterStats( new_stacked_padded_statistics, new_conditional_preconditioners, global_stats.exponents, ) new_shampoo_state = ShampooState( count=state.count + 1, stats=ShardedShampooStats(new_global_stats, new_local_stats), ) return updates, new_shampoo_state def init_fn(params): """Initialise the optimiser's state.""" def _init(param): preconditioner = preconditioner_from_params(param) statistics = [] preconditioners = [] if not _skip_preconditioning(param): shapes = preconditioner.shapes_for_preconditioners() statistics = [ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes ] preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes] diagonal_statistics = [] if _graft_type_has_diagonal_statistics(): diagonal_statistics = jnp.zeros_like(param) diagonal_momentum = _quantize_momentum(jnp.zeros_like(param)) momentum = _quantize_momentum(jnp.zeros_like(param)) return ParameterStats( _quantize_diagonal_statistics(diagonal_statistics), _maybe_quantize_statistics(statistics), _maybe_quantize_preconditioners(preconditioners), diagonal_momentum, momentum, init_training_metrics(len(statistics)), ) return ShampooState( count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params) ) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( [s > skip_preconditioning_dim_size_gt for s in param.shape] ) def _compute_stats(grad, state, param, step): """Compute per-parameter statistics.""" preconditioner = preconditioner_from_params(param) new_statistics = [[]] * len(state.statistics) w1 = beta2 w2 = beta2 if beta2 == 1.0 else (1.0 - beta2) if not _skip_preconditioning(param): def compute_updated_statistics(): return preconditioner.updated_statistics_from_grad( state.statistics, grad, w1=w1, w2=w2, to_float=_to_float, from_float=lambda x: 
_maybe_quantize_statistics([x])[0], precision=tensordot_precision, ) if statistics_compute_steps > 1: perform_step = step % statistics_compute_steps == 0 init_state = state.statistics new_statistics = list( efficient_cond(perform_step, compute_updated_statistics, init_state) ) else: new_statistics = compute_updated_statistics() return ParameterStats( state.diagonal_statistics, new_statistics, state.preconditioners, state.diagonal_momentum, state.momentum, state.training_metrics, ) mi_pth_root = functools.partial( matrix_inverse_pth_root, ridge_epsilon=matrix_epsilon, precision=precision, relative_matrix_epsilon=relative_matrix_epsilon, lobpcg_topk_precondition=lobpcg_topk_precondition, lobpcg_max_iter=lobpcg_max_iter, ) def _matrix_inverse_pth_root_vmap(xs, ps): return jax.vmap(mi_pth_root)(xs, ps) def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps): def _quantized_to_float(qx, qd, qb): qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape)) return qv.to_float() def matrix_inverse_pth_root_wrapper(qx, qd, qb, p): v = _quantized_to_float(qx, qd, qb) preconditioner, error = mi_pth_root(v, p) qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True) return qp.quantized, qp.diagonal, qp.bucket_size, error return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps) def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None): # Partition the concatenated statistics matrix across all cores. pspec_for_partition = preconditioner_partition_spec partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition) if preconditioner_partition_spec: partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0]) else: partitioned_ps_spec = None partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec) # Run matrix inverse pth root on each shard. 
partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap( partitioned_xs, partitioned_ps ) # Reshard output to have the same PSpec as input. This is required to avoid # vmap seeing the full set of statistics. partitioned_preconditioners = pjit.with_sharding_constraint( partitioned_preconditioners, pspec_for_partition ) # Recombine the outputs at each core. preconditioners = pjit.with_sharding_constraint( partitioned_preconditioners, statistics_partition_spec ) errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec()) return preconditioners, errors def _pmap_compute_preconditioners( states, step, statistics, num_statistics_per_state, original_shapes, exponents, max_size, prev_preconditioners, ): """Computes preconditioners for given statistics in states in PMAP mode. Args: states: A list of optimizer states. step: Current step number statistics: A list of statistics for all variables (for every dim) num_statistics_per_state: Number of statistis per state to reconstruct output states. original_shapes: A list of shapes of the statistics. exponents: Exponent power to use for inverse-pth roots. max_size: Maximum dim of the statistics to pad. prev_preconditioners: Previously available preconditioner. Returns: New optimizer states after computing the preconditioner. """ if batch_axis_name: num_devices = lax.psum(1, batch_axis_name) else: num_devices = 1 num_statistics = len(statistics) # Pad statistics and exponents to next multiple of num_devices. 
packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics] to_pad = -num_statistics % num_devices packed_statistics.extend( [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)] ) exponents.extend([1 for _ in range(to_pad)]) if not packed_statistics: return states all_statistics = batch(packed_statistics, num_devices) all_exponents = batch(exponents, num_devices) def _internal_inverse_pth_root_all(): if batch_axis_name: current_replica = lax.axis_index(batch_axis_name) preconditioners, errors = _matrix_inverse_pth_root_vmap( all_statistics[current_replica], all_exponents[current_replica] ) preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) errors = jax.lax.all_gather(errors, batch_axis_name) preconditioners_flat = unbatch(preconditioners) errors_flat = unbatch(errors) else: preconditioners, errors = _matrix_inverse_pth_root_vmap( all_statistics[0], all_exponents[0] ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) errors_flat = unbatch(jnp.stack([errors])) return preconditioners_flat, errors_flat if preconditioning_compute_steps == 1: preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() else: # Passing statistics instead of preconditioners as they are similarly # shaped tensors. Note statistics will be ignored as we are passing in # a large init value for error. 
preconditioners_init = packed_statistics errors_init = [inverse_failure_threshold] * len(packed_statistics) init_state = [preconditioners_init, errors_init] perform_step = step % preconditioning_compute_steps == 0 preconditioners_flat, errors_flat = efficient_cond( perform_step, _internal_inverse_pth_root_all, init_state ) def _skip(error): condition = jnp.logical_or( jnp.isnan(error), error >= inverse_failure_threshold ) return condition.astype(error.dtype) def _select_preconditioner(error, new_p, old_p): return lax.cond( _skip(error), lambda _: old_p, lambda _: new_p, operand=None ) new_preconditioners_flat = [] new_errors_flat = [] for p, shape, prev_p, error in zip( preconditioners_flat, original_shapes, prev_preconditioners, errors_flat ): new_preconditioners_flat.append( _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p) ) new_errors_flat.append(error) assert len(states) == len(num_statistics_per_state) assert len(new_preconditioners_flat) == num_statistics assert len(new_errors_flat) == num_statistics # Add back empty preconditioners so we that we can set the optimizer state. 
preconditioners_for_states = [] idx = 0 errors_for_states = [] for num_statistics, state in zip(num_statistics_per_state, states): if num_statistics == 0: preconditioners_for_states.append([]) errors_for_states.append(jnp.array(0, jnp.float32)) else: preconditioners_for_state = new_preconditioners_flat[ idx : idx + num_statistics ] assert len(state.statistics) == len(preconditioners_for_state) preconditioners_for_states.append(preconditioners_for_state) errors_for_state = jnp.stack( new_errors_flat[idx : idx + num_statistics] ) assert len(state.statistics) == len(errors_for_state) errors_for_states.append(errors_for_state) idx += num_statistics new_states = [] for state, new_preconditioners, new_errors in zip( states, preconditioners_for_states, errors_for_states ): if state.statistics: new_errors = jnp.where( jnp.logical_and( new_errors > 0.0, new_errors != inverse_failure_threshold ), new_errors, state.training_metrics.inverse_pth_root_errors, ) new_training_metrics = TrainingMetrics(new_errors) new_states.append( ParameterStats( state.diagonal_statistics, state.statistics, new_preconditioners, state.diagonal_momentum, state.momentum, new_training_metrics, ) ) return new_states def _pmap_quantized_compute_preconditioners( states, step, statistics, num_statistics_per_state, original_shapes, exponents, max_size, prev_preconditioners, ): """Computes preconditioners for given statistics in states in PMAP mode. For quantization, each statistic is represented by three values: quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots without ever recreating the original matrix in f32. Args: states: A list of optimizer states. step: Current step number statistics: A list of statistics for all variables (for every dim) num_statistics_per_state: Number of statistis per state to reconstruct output states. original_shapes: A list of shapes of the statistics. exponents: Exponent power to use for inverse-pth roots. max_size: Maximum dim of the statistics to pad. 
prev_preconditioners: Previously available preconditioner. Returns: New optimizer states after computing the preconditioner. """ num_devices = lax.psum(1, batch_axis_name) num_statistics = len(statistics) quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() # Complexity here is around: shapes needing be statically shaped, # our custom quantization type requires a different type of packing. # Parallel tensors: # quantized [dxd] # diagonals [d] f32 # bucket_sizes [d] f32 packed_quantized_statistics = [ pad_square_matrix(stat.quantized, max_size) for stat in statistics ] packed_quantized_diagonals = [ pad_vector(stat.diagonal, max_size) for stat in statistics ] packed_quantized_bucket_sizes = [ pad_vector(stat.bucket_size, max_size) for stat in statistics ] to_pad = -num_statistics % num_devices padded_eye = jnp.eye(max_size, dtype=jnp.float32) quantized_eye = QuantizedValue.from_float_value( padded_eye, quantized_dtype, True ) packed_quantized_statistics.extend( [quantized_eye.quantized for _ in range(to_pad)] ) packed_quantized_diagonals.extend( [quantized_eye.diagonal for _ in range(to_pad)] ) packed_quantized_bucket_sizes.extend( [quantized_eye.bucket_size for _ in range(to_pad)] ) exponents.extend([1 for _ in range(to_pad)]) if not packed_quantized_statistics: return states all_quantized_statistics = batch(packed_quantized_statistics, num_devices) all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices) all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices) all_exponents = batch(exponents, num_devices) def _internal_inverse_pth_root_all(): current_replica = lax.axis_index(batch_axis_name) ( quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors, ) = _quantized_matrix_inverse_pth_root_vmap( all_quantized_statistics[current_replica], all_quantized_diagonals[current_replica], all_quantized_bucket_sizes[current_replica], all_exponents[current_replica], ) quantized_preconditioners = 
jax.lax.all_gather( quantized_preconditioners, batch_axis_name ) quantized_diagonals = jax.lax.all_gather( quantized_diagonals, batch_axis_name ) quantized_bucket_sizes = jax.lax.all_gather( quantized_bucket_sizes, batch_axis_name ) errors = jax.lax.all_gather(errors, batch_axis_name) quantized_preconditioners_flat = unbatch(quantized_preconditioners) quantized_diagonals_flat = unbatch(quantized_diagonals) quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes) errors_flat = unbatch(errors) return ( quantized_preconditioners_flat, quantized_diagonals_flat, quantized_bucket_sizes_flat, errors_flat, ) if preconditioning_compute_steps == 1: ( quantized_preconditioners_flat, quantized_diagonals_flat, quantized_bucket_sizes_flat, errors_flat, ) = _internal_inverse_pth_root_all() else: # Passing statistics instead of preconditioners as they are similarly # shaped tensors. Note statistics will be ignored as we are passing in # a large init value for error. quantized_preconditioners_init = packed_quantized_statistics quantized_diagonals_init = packed_quantized_diagonals quantized_bucket_sizes_init = packed_quantized_bucket_sizes errors_init = [inverse_failure_threshold] * len( quantized_preconditioners_init ) init_state = [ quantized_preconditioners_init, quantized_diagonals_init, quantized_bucket_sizes_init, errors_init, ] perform_step = step % preconditioning_compute_steps == 0 ( quantized_preconditioners_flat, quantized_diagonals_flat, quantized_bucket_sizes_flat, errors_flat, ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state) def _skip(error): condition = jnp.logical_or( jnp.isnan(error), error >= inverse_failure_threshold ) return condition.astype(error.dtype) def _select_preconditioner(error, new_p, old_p): return lax.cond( _skip(error), lambda _: old_p, lambda _: new_p, operand=None ) new_quantized_preconditioners_flat = [] new_quantized_diagonals_flat = [] new_quantized_bucket_sizes_flat = [] new_errors_flat = [] for p, d, b, shape, 
prev_p, error in zip( quantized_preconditioners_flat, quantized_diagonals_flat, quantized_bucket_sizes_flat, original_shapes, prev_preconditioners, errors_flat, ): new_quantized_preconditioners_flat.append( _select_preconditioner( error, p[: shape[0], : shape[1]], prev_p.quantized ) ) new_quantized_diagonals_flat.append( _select_preconditioner(error, d[: shape[0]], prev_p.diagonal) ) new_quantized_bucket_sizes_flat.append( _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size) ) new_errors_flat.append(error) assert len(states) == len(num_statistics_per_state) assert len(new_quantized_preconditioners_flat) == num_statistics assert len(new_quantized_diagonals_flat) == num_statistics assert len(new_quantized_bucket_sizes_flat) == num_statistics # Add back empty preconditioners so we that we can set the optimizer state. preconditioners_for_states = [] errors_for_states = [] idx = 0 for num_statistics, state in zip(num_statistics_per_state, states): if num_statistics == 0: preconditioners_for_states.append([]) errors_for_states.append(jnp.array(0, jnp.float32)) else: quantized_preconditioners_for_state = ( new_quantized_preconditioners_flat[idx : idx + num_statistics] ) quantized_diagonals_for_state = new_quantized_diagonals_flat[ idx : idx + num_statistics ] quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[ idx : idx + num_statistics ] errors_for_state = jnp.stack( new_errors_flat[idx : idx + num_statistics] ) assert len(state.statistics) == len(quantized_preconditioners_for_state) assert len(state.statistics) == len(quantized_diagonals_for_state) assert len(state.statistics) == len(quantized_bucket_sizes_for_state) assert len(state.statistics) == len(errors_for_state) quantized_preconditioners = [] for qv, qd, qb in zip( quantized_preconditioners_for_state, quantized_diagonals_for_state, quantized_bucket_sizes_for_state, ): quantized_preconditioners.append( QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)) ) 
preconditioners_for_states.append(quantized_preconditioners) errors_for_states.append(errors_for_state) idx += num_statistics new_states = [] for state, new_preconditioners, new_errors in zip( states, preconditioners_for_states, errors_for_states ): if state.statistics: new_errors = jnp.where( jnp.logical_and( new_errors > 0.0, new_errors != inverse_failure_threshold ), new_errors, state.training_metrics.inverse_pth_root_errors, ) new_training_metrics = TrainingMetrics(new_errors) new_states.append( ParameterStats( state.diagonal_statistics, state.statistics, new_preconditioners, state.diagonal_momentum, state.momentum, new_training_metrics, ) ) return new_states def _pjit_compute_preconditioners( states, step, statistics, num_statistics_per_state, original_shapes, exponents, max_size, prev_preconditioners, ): """Computes preconditioners for given statistics in states in PJIT mode. Args: states: A list of optimizer states. step: Current step number statistics: A list of statistics for all variables (for every dim) num_statistics_per_state: Number of statistis per state to reconstruct output states. original_shapes: A list of shapes of the statistics. exponents: Exponent power to use for inverse-pth roots. max_size: Maximum dim of the statistics to pad. prev_preconditioners: Previously available preconditioner. Returns: New optimizer states after computing the preconditioner. 
""" num_statistics = len(statistics) to_pad = -num_statistics % num_devices_for_pjit padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics] padded_statistics.extend( [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)] ) exponents.extend([1 for _ in range(to_pad)]) all_statistics = jnp.stack(padded_statistics) all_exponents = jnp.stack(exponents) def _internal_inverse_pth_root_all(): preconditioners, errors = _matrix_inverse_pth_root_pjit( all_statistics, all_exponents ) b1 = preconditioners.shape[0] def split(batched_values): return [ jnp.squeeze(v) for v in jnp.split(batched_values, indices_or_sections=b1, axis=0) ] return split(preconditioners), split(errors) if preconditioning_compute_steps == 1: preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() else: # Passing statistics instead of preconditioners as they are similarly # shaped tensors. Note statistics will be ignored as we are passing in # a large init value for error. 
preconditioners_init = padded_statistics errors_init = [inverse_failure_threshold] * len(padded_statistics) init_state = [preconditioners_init, errors_init] perform_step = step % preconditioning_compute_steps == 0 preconditioners_flat, errors_flat = efficient_cond( perform_step, _internal_inverse_pth_root_all, init_state ) def _skip(error): condition = jnp.logical_or( jnp.isnan(error), error >= inverse_failure_threshold ) return condition.astype(error.dtype) def _select_preconditioner(error, new_p, old_p): return lax.cond( _skip(error), lambda _: old_p, lambda _: new_p, operand=None ) new_preconditioners_flat = [] new_errors_flat = [] for p, shape, prev_p, error in zip( preconditioners_flat, original_shapes, prev_preconditioners, errors_flat ): new_preconditioners_flat.append( _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p) ) new_errors_flat.append(error) assert len(states) == len(num_statistics_per_state) assert len(new_preconditioners_flat) == num_statistics # Add back empty preconditioners so we that we can set the optimizer state. 
preconditioners_for_states = [] errors_for_states = [] idx = 0 for num_statistics, state in zip(num_statistics_per_state, states): if num_statistics == 0: preconditioners_for_states.append([]) errors_for_states.append(jnp.array(0, jnp.float32)) else: preconditioners_for_state = new_preconditioners_flat[ idx : idx + num_statistics ] assert len(state.statistics) == len(preconditioners_for_state) preconditioners_for_states.append(preconditioners_for_state) errors_for_state = jnp.stack( new_errors_flat[idx : idx + num_statistics] ) assert len(state.statistics) == len(errors_for_state) errors_for_states.append(errors_for_state) idx += num_statistics new_states = [] for state, new_preconditioners, new_errors in zip( states, preconditioners_for_states, errors_for_states ): if state.statistics: new_errors = jnp.where( jnp.logical_and( new_errors > 0.0, new_errors != inverse_failure_threshold ), new_errors, state.training_metrics.inverse_pth_root_errors, ) new_training_metrics = TrainingMetrics(new_errors) new_states.append( ParameterStats( state.diagonal_statistics, state.statistics, new_preconditioners, state.diagonal_momentum, state.momentum, new_training_metrics, ) ) return new_states def _compute_preconditioners(states, params, step): """Computes preconditioners for given statistics in states. Args: states: A list of optimizer states. params: A list of params. step: Current step number Returns: New optimizer states after computing the preconditioner. 
""" statistics = [] num_statistics_per_state = [] original_shapes = [] exponents = [] max_size = 0 prev_preconditioners = [] for state, param in zip(states, params): num_statistics = len(state.statistics) num_statistics_per_state.append(num_statistics) original_shapes_for_state = [] if num_statistics > 0: preconditioner = preconditioner_from_params(param) for statistic in state.statistics: exponents.append( preconditioner.exponent_for_preconditioner() if exponent_override == 0 else exponent_override ) original_shapes_for_state.append(statistic.shape) max_size = max(max_size, statistic.shape[0]) statistics.extend(state.statistics) prev_preconditioners.extend(state.preconditioners) original_shapes.extend(original_shapes_for_state) if not shard_optimizer_states: # Quantization is only enabled if batch_axis_name is not set. quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers() if quantized_dtype == jnp.float32: return _pmap_compute_preconditioners( states, step, statistics, num_statistics_per_state, original_shapes, exponents, max_size, prev_preconditioners, ) else: return _pmap_quantized_compute_preconditioners( states, step, statistics, num_statistics_per_state, original_shapes, exponents, max_size, prev_preconditioners, ) else: return _pjit_compute_preconditioners( states, step, statistics, num_statistics_per_state, original_shapes, exponents, max_size, prev_preconditioners, ) def _transform_grad(grad, state, param, step): """Transform per-parameter gradients.""" preconditioner = preconditioner_from_params(param) sgd_update = grad new_diagonal_statistics = state.diagonal_statistics.to_float() if ( graft_type == GraftingType.ADAGRAD or graft_type == GraftingType.ADAGRAD_NORMALIZED ): scaled_grad = grad if graft_type == GraftingType.ADAGRAD_NORMALIZED: scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16) new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square( scaled_grad ) adagrad_update = scaled_grad / ( 
jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon ) grafting_update = adagrad_update elif ( graft_type == GraftingType.RMSPROP or graft_type == GraftingType.RMSPROP_NORMALIZED ): scaled_grad = grad if graft_type == GraftingType.RMSPROP_NORMALIZED: scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16) w1 = beta2 w2 = beta2 if beta2 == 1.0 else (1.0 - beta2) new_diagonal_statistics = ( w1 * state.diagonal_statistics.to_float() + w2 * jnp.square(scaled_grad) ) rmsprop_update = scaled_grad / ( jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon ) if clip_by_scaled_gradient_norm: scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / ( jnp.sqrt(float(rmsprop_update.size)) ) clipping_denom = jnp.maximum( 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm ) rmsprop_update /= clipping_denom grafting_update = rmsprop_update elif graft_type == GraftingType.SGD: grafting_update = sgd_update else: grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update) lr = learning_rate if callable(learning_rate): lr = learning_rate(step) preconditioner_multiplier = lr if not decoupled_learning_rate else 1.0 grafting_update = grafting_update * preconditioner_multiplier precond_grad = grad if not _skip_preconditioning(param): precond_grad = preconditioner.preconditioned_grad( precond_grad, _maybe_dequantize_preconditioners(state.preconditioners) ) else: precond_grad = grafting_update grafting_update_norm = jnp.linalg.norm(grafting_update) precond_grad_norm = jnp.linalg.norm(precond_grad) multiplier = grafting_update_norm / (precond_grad_norm + 1e-16) shampoo_update = precond_grad * multiplier shampoo_update_with_wd = shampoo_update grafting_update_with_wd = grafting_update if weight_decay != 0 and not decoupled_weight_decay: shampoo_update_with_wd = shampoo_update + weight_decay * param grafting_update_with_wd = grafting_update + weight_decay * param w = (1.0 - beta1) if moving_average_for_momentum else 1.0 shampoo_update_with_wd_momentum = ( state.momentum.to_float() * beta1 + w 
* shampoo_update_with_wd ) grafting_update_with_wd_momentum = ( state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd ) run_shampoo = (step >= start_preconditioning_step).astype( grafting_update_with_wd_momentum.dtype ) momentum_update = ( run_shampoo * shampoo_update_with_wd_momentum + (1.0 - run_shampoo) * grafting_update_with_wd_momentum ) wd_update = ( run_shampoo * shampoo_update_with_wd + (1.0 - run_shampoo) * grafting_update_with_wd ) nesterov_momentum_update = momentum_update if nesterov: nesterov_momentum_update = w * wd_update + beta1 * momentum_update if weight_decay != 0 and decoupled_weight_decay: nesterov_momentum_update = ( nesterov_momentum_update + lr * weight_decay * param ) momentum_multiplier = lr if decoupled_learning_rate else 1.0 transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update new_diagonal_momentum = grafting_update_with_wd_momentum new_momentum = shampoo_update_with_wd_momentum param_stats = ParameterStats( _quantize_diagonal_statistics(new_diagonal_statistics), state.statistics, state.preconditioners, _quantize_momentum(new_diagonal_momentum), _quantize_momentum(new_momentum), state.training_metrics, ) return transformed_update, param_stats def update_fn(grads, state, params): """Transform the input gradient and update all statistics. Args: grads: the gradient tensors for the parameters and any custom gradients for preconditioners. state: a named tuple containing the state of the optimizer params: the parameters that should be updated. Returns: A tuple containing the new parameters and the new optimizer state. 
""" params_flat, treedef = jax.tree_flatten(params) stats_flat = treedef.flatten_up_to(state.stats) grads_flat = treedef.flatten_up_to(grads) stats_grads = grads_flat new_stats_flat = jax.tree_map( lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat, ) new_stats_flat = _compute_preconditioners( new_stats_flat, params_flat, state.count ) outputs = jax.tree_map( lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat, ) updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) updates = jax.tree_unflatten(treedef, updates_flat) new_stats = jax.tree_unflatten(treedef, new_stats_flat) new_state = ShampooState(count=state.count + 1, stats=new_stats) return updates, new_state if shard_optimizer_states: # Hijacks the init_fn signature so we can return an OptState with # appropriate init_fns. opt_init_fn = sharded_init_fn def _init_fns(unused_params): return InitFnState( init_fn=opt_init_fn, pspec_fn=sharded_init_partition_spec_fn, shape_and_dtype_fn=sharded_init_shape_and_dtype_fn, ) opt_update_fn = sharded_update_fn return optax.GradientTransformation(_init_fns, opt_update_fn) else: return optax.GradientTransformation(init_fn, update_fn) ================================================ FILE: tools/train/scalable_shampoo/quantization_utils.py ================================================ # coding=utf-8 # Copyright 2022 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Helper routines for quantization.""" from typing import Any import chex import jax.numpy as jnp from flax import struct # pylint:disable=no-value-for-parameter @struct.dataclass class QuantizedValue: """State associated with quantized value.""" quantized: chex.Array diagonal: chex.Array # Diagonal (if extract_diagonal is set) bucket_size: chex.Array quantized_dtype: jnp.dtype = struct.field( pytree_node=False ) # Dtype for the quantized value. extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered. shape: Any = struct.field(pytree_node=False) # Shape of the tensor. @classmethod def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): if isinstance(fvalue, list) and not fvalue: return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( fvalue, quantized_dtype, extract_diagonal ) return QuantizedValue( quantized, diagonal_fvalue, bucket_size, quantized_dtype, extract_diagonal, list(quantized.shape), ) # Quantization is from Lingvo JAX optimizers. # We extend it for int16 quantization of PSD matrices. @classmethod def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): """Returns quantized value and the bucket.""" if quantized_dtype == jnp.float32: return fvalue, [], [] elif quantized_dtype == jnp.bfloat16: return fvalue.astype(jnp.bfloat16), [], [] float_dtype = fvalue.dtype if quantized_dtype == jnp.int8: # value -128 is not used. num_buckets = jnp.array(127.0, dtype=float_dtype) elif quantized_dtype == jnp.int16: # value -32768 is not used. num_buckets = jnp.array(32767.0, dtype=float_dtype) else: raise ValueError(f"Quantized dtype {quantized_dtype} not supported.") # max value is mapped to num_buckets if extract_diagonal and fvalue.ndim != 2: raise ValueError( f"Input array {fvalue} must be 2D to work with extract_diagonal." 
) diagonal_fvalue = [] if extract_diagonal: diagonal_fvalue = jnp.diag(fvalue) # Remove the diagonal entries. fvalue = fvalue - jnp.diag(diagonal_fvalue) # TODO(rohananil): Extend this by making use of information about the blocks # SM3 style which will be useful for diagonal statistics # We first decide the scale. if fvalue.ndim < 1: raise ValueError( f"Input array {fvalue} must have a strictly positive number of dimensions." ) max_abs = jnp.max(jnp.abs(fvalue), axis=0) bucket_size = max_abs / num_buckets bs_expanded = bucket_size[jnp.newaxis, Ellipsis] # To avoid divide by 0.0 bs_nonzero = jnp.where( bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) ) ratio = fvalue / bs_nonzero # We use rounding to remove bias. quantized = jnp.round(ratio) return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size def to_float(self): """Returns the float value.""" if isinstance(self.quantized, list) and not self.quantized: return self.quantized if self.quantized_dtype == jnp.float32: return self.quantized if self.quantized_dtype == jnp.bfloat16: return self.quantized.astype(jnp.float32) float_dtype = self.bucket_size.dtype bucket_size = self.bucket_size[jnp.newaxis, Ellipsis] val = self.quantized.astype(float_dtype) * bucket_size if self.extract_diagonal: val += jnp.diag(self.diagonal) return val ================================================ FILE: tools/train/scalable_shampoo/sm3.py ================================================ # coding=utf-8 # Copyright 2022 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # An implementation of SM3 from: # # Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf # Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer # # Author: Rohan Anil (rohananil at google dot com) # """SM3 Implementation.""" import functools from typing import Any, NamedTuple import chex import jax import jax.numpy as jnp import optax from .quantization_utils import QuantizedValue class SM3State(NamedTuple): count: chex.Array stats: Any # Per parameter optimizer state used in data-parallel training. class ParameterStats(NamedTuple): """State associated to each parameter of the model being trained.""" diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner def sm3( learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False ): """SM3 optimizer. Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer https://arxiv.org/abs/1901.11150 Args: learning_rate: the step size used to update the parameters. beta1: momentum parameter. beta2: second moment averaging parameter. diagonal_epsilon: epsilon for sm3 normalize_grads: Whether to normalize grads. Author finds it useful when grads are high variance. Returns: a GradientTransformation. """ def _quantize_momentum(momentum_statistics): return QuantizedValue.from_float_value(momentum_statistics, jnp.int8) def init_fn(params): """Initialise the optimiser's state.""" def _init(param): accumulators = [jnp.zeros([s]) for s in param.shape] momentum = _quantize_momentum(jnp.zeros_like(param)) return ParameterStats(accumulators, momentum) return SM3State( count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params) ) def _get_expanded_shape(shape, i): rank = len(shape) # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. 
# For eg: i = 1 returns [1, N, 1]. return [1] * i + [shape[i]] + [1] * (rank - i - 1) def _moving_averages(grad, accumulators): w = (1.0 - beta2) if beta2 != 1.0 else 1.0 if grad.ndim < 2: return beta2 * accumulators[0] + w * grad**2 else: min_accumulator = functools.reduce(jnp.minimum, accumulators) return beta2 * min_accumulator + w * grad**2 def _moving_averages_momentum(grad, momentum): w = (1.0 - beta1) if beta1 != 1.0 else 1.0 return beta1 * momentum.to_float() + w * grad def _sketch_diagonal_statistics(grad, updated_diagonal_statistics): all_diagonal_statistics = [] for i in range(grad.ndim): axes = list(range(i)) + list(range(i + 1, grad.ndim)) dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes) all_diagonal_statistics.append(dim_diagonal_statistics) if grad.ndim == 1: all_diagonal_statistics[0] = updated_diagonal_statistics return all_diagonal_statistics def update_fn(updates, state, params=None): del params stats = state.stats if normalize_grads: updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates) # Reshape all vectors into N-d tensors to compute min over them. # [n], [m] -> [n, 1], [1, m] expanded_diagonal_statistics = jax.tree_map( lambda grad, state: [ # pylint:disable=g-long-lambda jnp.reshape( state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i) ) for i in range(grad.ndim) ], updates, stats, ) # Compute new diagonal statistics new_diagonal_statistics = jax.tree_map( _moving_averages, updates, expanded_diagonal_statistics ) # Compute preconditioners (1/sqrt(s)) where s is the statistics. 
new_preconditioners = jax.tree_map( lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics ) preconditioned_grads = jax.tree_map( lambda g, p: g * p, updates, new_preconditioners ) # Compute updated momentum (also handle quantization) updated_momentum = jax.tree_map( lambda preconditioned_grad, state: _moving_averages_momentum( # pylint:disable=g-long-lambda preconditioned_grad, state.diagonal_momentum ), preconditioned_grads, stats, ) # Update diagonal statistics. updated_diagonal_statistics = jax.tree_map( _sketch_diagonal_statistics, updates, new_diagonal_statistics ) # Update momentum. new_sm3_stats = jax.tree_map( lambda momentum, diagonal_stats: ParameterStats( # pylint:disable=g-long-lambda diagonal_stats, _quantize_momentum(momentum) ), updated_momentum, updated_diagonal_statistics, ) lr = learning_rate if callable(learning_rate): lr = learning_rate(state.count) new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum) return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats) return optax.GradientTransformation(init_fn, update_fn) ================================================ FILE: tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py ================================================ # coding=utf-8 # Copyright 2022 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""JAX Ops for symmetric matrices used by the Shampoo optimizer.""" import functools from typing import Any, List, Optional, Sequence, Union import jax import jax.numpy as jnp import numpy as np from flax import struct from jax import lax @struct.dataclass class SlicedSymmetricMatrix: """A symmetric matrix represented by lower-triangular block row slices. For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented by the block rows a and [b, c]. The matrix may be batched, in which case each entry of block_rows may have dimension greater than 2. The last two dimensions represent the rows and cols. """ block_rows: List[jnp.ndarray] def product_with_transpose( mat1, mat2, axes, precision=lax.Precision.DEFAULT, ): """Returns mat1 * mat2^T for two matrices (possibly batched). The rows and columns are the last two dimensions for each matrix. Args: mat1: First matrix. mat2: Second matrix. axes: The axes over which to apply the product. precision: JAX precision to use for the multiplication. """ return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision) @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision")) def sliced_transposed_product( mat, block_size, axes=(-1,), precision=lax.Precision.DEFAULT, ): """Returns the blocked slices representing a symmetric contraction. Specifically, the output is a contraction of the input mat with itself, in the specified axes. Args: mat: The matrix for which we will compute a contraction with itself. block_size: The size of row blocks to compute. axes: Axes to use for the contraction. precision: The precision to use in each computation. Raises: ValueError: Raised when the specified block size does not evenly divide the number of rows of the input mat. 
""" rank = len(mat.shape) def _make_axis_positive(ax): assert -rank <= ax < rank return ax + rank if ax < 0 else ax positive_axes = [_make_axis_positive(ax) for ax in axes] assert len(positive_axes) == len(axes) remaining_axes = set(range(rank)) - set(positive_axes) assert len(remaining_axes) == 1 remaining_ax = remaining_axes.pop() num_rows = mat.shape[remaining_ax] if num_rows % block_size != 0: raise ValueError( "The row dimension must be divisible by block_size. " f"Instead got row dimension={num_rows} and block_size={block_size}." ) block_rows = [] for i in range(num_rows // block_size): start_indices = [0] * rank start_indices[remaining_ax] = i * block_size slice_sizes = list(mat.shape) slice_sizes[remaining_ax] = block_size slice_sizes_full = list(mat.shape) slice_sizes_full[remaining_ax] = (i + 1) * block_size block_rows.append( product_with_transpose( lax.dynamic_slice( mat, start_indices=start_indices, slice_sizes=slice_sizes ), lax.dynamic_slice( mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full ), axes=(axes, axes), precision=precision, ) ) return SlicedSymmetricMatrix(block_rows=block_rows) @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision")) def sliced_transposed_product_concat( mat, block_size, axes=(-1,), precision=lax.Precision.DEFAULT, ): """Returns the concatenated slices representing mat*mat^T. Args: mat: The matrix for which we will compute mat*mat^T. It does not need to be square, and may be batched. block_size: The size of row blocks to compute. axes: Axes to use for the contraction. precision: The precision to use in each computation. Raises: ValueError: Raised when the specified block size does not evenly divide the number of rows of the input mat. 
""" sliced_symmetric_matrix = sliced_transposed_product( mat=mat, block_size=block_size, axes=axes, precision=precision ) return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1) @jax.jit def materialize_matrix(symmetric_matrix): """Returns a materialized symmetric matrix. Args: symmetric_matrix: the matrix represented by lower-triangular block slices. """ block_rows = symmetric_matrix.block_rows block_size = block_rows[0].shape[-2] num_blocks = len(block_rows) # Slice the lower-triangular and diagonal blocks into blocks. blocks = [ [ block_row[Ellipsis, i * block_size : (i + 1) * block_size] for i in range(k + 1) ] for k, block_row in enumerate(block_rows) ] # Generate the (off-diagonal) upper-triangular blocks. off_diags = [[] for _ in range(num_blocks - 1)] for k, block_row in enumerate(block_rows[1:]): for i in range(k + 1): off_diags[i].append( jnp.swapaxes( a=block_row[Ellipsis, i * block_size : (i + 1) * block_size], axis1=-1, axis2=-2, ) ) return jnp.block( [row + row_t for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]] ) @functools.partial(jax.jit, static_argnames="num_blocks") def materialize_matrix_from_concat( block_rows_concat, num_blocks=None, ): """Returns a materialized symmetric matrix from concatenated slices. Args: block_rows_concat: The matrix represented as the concatenated lower-triangular blocks. num_blocks: The number of block-rows used to represent the symmetric matrix. If not specified, it is inferred from the shape of block_rows_concat. 
""" if num_blocks is None: num_blocks = find_num_blocks(block_rows_concat) block_size = block_rows_concat.shape[-2] block_rows = [ block_rows_concat[ Ellipsis, (k * (k + 1)) // 2 * block_size : (((k + 1) * (k + 2)) // 2 + 1) * block_size, ] for k in range(num_blocks) ] return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows)) @functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes")) def update_sliced_rows( symmetric_matrix, mat, alpha, beta, axes=(-1,), ): """Implements the blocked equivalent of SYRK. Specifically, the symmetric matrix (represented using lower-triangular block rows) is updated using the sliced product of mat. Args: symmetric_matrix: The symmetric matrix to update. mat: The matrix to use for the update = mat * mat^T. The number of rows should match that of symmetric_matrix. alpha: The weight for the update. beta: The weight for the original symmetric matrix. axes: Axes to use for the contraction of the update. Returns: The updated rows of alpha * mat * mat^T + beta * symmetric_matrix. """ block_size = symmetric_matrix.block_rows[0].shape[-2] sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes) return SlicedSymmetricMatrix( block_rows=[ update * alpha + row * beta for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows) ] ) def num_blocks_from_total_blocks(total_blocks): """Returns the number of blocks (i.e. block rows) from the total blocks. This is the inverse of the function x -> x*(x+1)/2. For example, the matrix M = [[A, B^T], [B, C]] may be represented using a total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2. Args: total_blocks: The total blocks used to represent the matrix. """ num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32) if (num_blocks * (num_blocks + 1)) / 2 != total_blocks: raise ValueError( f"total_blocks={total_blocks} does not correspond to " "a symmetric matrix. 
It must have the form total_blocks = x*(x+1)/2." ) return num_blocks def find_num_blocks(block_rows_concat): """Returns the number of (row) blocks representing the concatenated matrix. For example, an input with dimensions [256, 2560] represents 10 square blocks, which matches 4 lower-triangular block rows (1+2+3+4). So this function will return 4. Use ordinary numpy functions here so that the returned value is static. Args: block_rows_concat: The concatenated block array. Raises: ValueError: When the dimensions of the matrix do not correspond to a lower triangular block representation. """ # Compute the number of square blocks used to represent the matrix. total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2] # Determine the number of block rows by inverting y = x*(x+1)/2. return num_blocks_from_total_blocks(total_blocks) @functools.partial(jax.jit, static_argnames="block_size") def slice_symmetric_matrix( mat, block_size, ): """Returns sliced row blocks. Args: mat: A symmetric matrix. block_size: The size of the row slices. """ num_rows = mat.shape[-2] num_cols = mat.shape[-1] if num_rows != num_cols: raise ValueError("mat is not square.") if num_rows % block_size != 0: raise ValueError( f"block size does not evenly divide rows. num_rows={num_rows}, block_size={block_size}" ) return SlicedSymmetricMatrix( block_rows=[ mat[ Ellipsis, i * block_size : (i + 1) * block_size, 0 : (i + 1) * block_size, ] for i in range(num_rows // block_size) ] ) @functools.partial(jax.jit, static_argnames="block_size") def slice_symmetric_matrix_concat( mat, block_size, ): """Returns the concatenated sliced row blocks. Args: mat: A symmetric matrix. block_size: The size of the row slices. """ sliced_symmetric_matrix = slice_symmetric_matrix(mat=mat, block_size=block_size) return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1) def sliced_matrix_diag(mat): """Returns the diagonal of the symmetric matrix. 
Args: mat: The symmetric matrix represented in concatenated block form. """ rows, cols = mat.shape total_blocks = cols // rows num_blocks = num_blocks_from_total_blocks(total_blocks) diags = [] for i in range(num_blocks): last_index = rows * ((i + 2) * (i + 1)) // 2 first_index = last_index - rows diags.append(jnp.diag(mat[Ellipsis, first_index:last_index])) return jnp.concatenate(diags, axis=-1) def diag_as_concat(diag, block_size): """Returns the representation of a diagonal matrix in symmetric block form. Args: diag: The 1D array for the diagonals. block_size: The size of blocks to use. Must divide the length of diag. """ assert len(diag.shape) == 1 # diag must be 1D. assert len(diag) % block_size == 0 num_diag_blocks = len(diag) // block_size blocks = [] for i in range(num_diag_blocks): blocks.append(jnp.zeros(shape=(block_size, block_size * i), dtype=diag.dtype)) blocks.append(jnp.diag(diag[i * block_size : (i + 1) * block_size])) return jnp.concatenate(blocks, axis=-1) def row_abs_maxes(mat): """Returns the max of the absolute values of the rows of the full matrix. For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using mat = [1, 6, 2] with block_size = 1. In this case the function returns the aboslute row maxes of the original symmetric matrix, [6, 6]. Args: mat: The symmetric matrix represented as the concatenated blocks. """ rows, cols = mat.shape # Find col and row max for each block. col_maxes = [] row_maxes = [] for i in range(cols // rows): block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows]) col_maxes.append(jnp.max(block, axis=1)) row_maxes.append(jnp.max(block, axis=0)) # global row max from block maxes. 
num_blocks = num_blocks_from_total_blocks(cols // rows) maxes = [] for i in range(num_blocks): maxes.append( jnp.concatenate( row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)] + [ col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)] for j in range(i + 1, num_blocks) ], axis=-1, ) ) return jnp.max(jnp.stack(maxes), axis=0) def times_vector(mat, vec): """Returns the symmetric block-concatenated matrix multiplied by a vector. Specifically, each value in the vector is multiplied by a row of the full matrix. That is, the vector is broadcast and multiplied element-wise. Note this would be the transpose of full_mat * vec if full_mat represented the full symmetric matrix. Args: mat: The symmetric matrix represented as the concatenated blocks. vec: The vector, having the same dimension as the materialized matrix. """ rows, cols = mat.shape num_blocks = num_blocks_from_total_blocks(cols // rows) multiplied = [] for i in range(num_blocks): mat_block = mat[ Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2 ] vec_block = vec[Ellipsis, rows * i : rows * (i + 1)] multiplied.append(jnp.einsum("...ij,...i->ij", mat_block, vec_block)) return jnp.concatenate(multiplied, axis=-1) ================================================ FILE: tools/train/sweep.yaml ================================================ program: train.py project: dalle-mini method: random metric: name: eval/loss goal: minimize parameters: optim: value: distributed_shampoo learning_rate: distribution: log_uniform # from exp(min) to exp(max) min: -9.2 max: -6.9 tokenizer_name: value: boris/dalle-mini-tokenizer config_name: value: ./config/mini dtype: value: bfloat16 dataset_repo_or_path: value: ./data per_device_train_batch_size: value: 64 per_device_eval_batch_size: value: 64 gradient_accumulation_steps: value: 1 warmup_steps: value: 1000 num_train_epochs: value: 1 max_train_samples: value: 1000000 logging_steps: value: 40 eval_steps: value: 200 command: - python3 - ${program} - "--streaming" - 
"--output_dir" - "./output" - "--overwrite_output_dir" - "--do_train" - "--do_eval" - ${args} ================================================ FILE: tools/train/train.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Training DALL·E Mini. Script adapted from run_summarization_flax.py """ import io import logging import os import sys import tempfile import time from dataclasses import asdict, dataclass, field from functools import partial from pathlib import Path from typing import Any, Callable, NamedTuple, Optional import datasets import flax import jax import jax.numpy as jnp import jaxlib import numpy as np import optax import transformers import wandb from datasets import Dataset from flax import core, struct, traverse_util from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.serialization import from_bytes, to_bytes from flax.training.common_utils import onehot from jax.experimental import PartitionSpec, maps from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.pjit import pjit, with_sharding_constraint from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo from tqdm import tqdm from transformers import HfArgumentParser import dalle_mini from dalle_mini.data import Dataset from dalle_mini.model import ( DalleBart, 
DalleBartConfig, DalleBartTokenizer, set_partitions, ) try: from google.cloud import storage except: storage = None logger = logging.getLogger(__name__) cc.initialize_cache("jax_cache") @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. """ model_name_or_path: Optional[str] = field( default=None, metadata={ "help": "The model checkpoint for weights initialization. " "Don't set if you want to train a model from scratch. " "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`." }, ) config_name: Optional[str] = field( default=None, metadata={ "help": "Pretrained config name or path if not the same as model_name_or_path" }, ) tokenizer_name: Optional[str] = field( default=None, metadata={ "help": "Pretrained tokenizer name or path if not the same as model_name_or_path" }, ) dtype: Optional[str] = field( default="float32", metadata={ "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`." }, ) restore_state: Optional[bool] = field( default=False, metadata={ "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path." }, ) dropout: Optional[float] = field( default=None, metadata={"help": "Dropout rate. Overwrites config."}, ) activation_dropout: Optional[float] = field( default=None, metadata={"help": "Activation dropout rate. Overwrites config."}, ) attention_dropout: Optional[float] = field( default=None, metadata={"help": "Attention dropout rate. 
Overwrites config."}, ) def __post_init__(self): if self.tokenizer_name is None: self.tokenizer_name = self.model_name_or_path assert ( self.tokenizer_name is not None ), "Tokenizer name or model name/path needs to be specified" if self.restore_state: assert self.model_name_or_path is not None and ( "/model-" in self.model_name_or_path ), "Restoring state only available with W&B artifact reference" def get_metadata(self): if self.model_name_or_path is not None and ":" in self.model_name_or_path: if jax.process_index() == 0: artifact = wandb.run.use_artifact(self.model_name_or_path) else: artifact = wandb.Api().artifact(self.model_name_or_path) return artifact.metadata else: return dict() def get_opt_state(self): with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies if self.restore_state is True: # wandb artifact state_artifact = self.model_name_or_path.replace( "/model-", "/state-", 1 ) if jax.process_index() == 0: artifact = wandb.run.use_artifact(state_artifact) else: artifact = wandb.Api().artifact(state_artifact) if artifact.metadata.get("bucket_path"): # we will read directly file contents self.restore_state = artifact.metadata["bucket_path"] else: artifact_dir = artifact.download(tmp_dir) self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack") if self.restore_state.startswith("gs://"): bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack" bucket, blob_name = str(bucket_path).split("/", 1) assert ( storage is not None ), 'Could not find google.storage. Install with "pip install google-cloud-storage"' client = storage.Client() bucket = client.bucket(bucket) blob = bucket.blob(blob_name) return blob.download_as_bytes() with Path(self.restore_state).open("rb") as f: return f.read() @dataclass class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. 
""" text_column: Optional[str] = field( default="caption", metadata={ "help": "The name of the column in the datasets containing the full texts (for summarization)." }, ) encoding_column: Optional[str] = field( default="encoding", metadata={ "help": "The name of the column in the datasets containing the image encodings." }, ) dataset_repo_or_path: str = field( default=None, metadata={"help": "The dataset repository containing encoded files."}, ) train_file: Optional[str] = field( default=None, metadata={ "help": "The input training data file (glob & braceexpand acceptable)." }, ) validation_file: Optional[str] = field( default=None, metadata={ "help": "An optional input evaluation data file (glob & braceexpand acceptable)." }, ) # data loading should not be a bottleneck so we use "streaming" mode by default streaming: Optional[bool] = field( default=True, metadata={"help": "Whether to stream the dataset."}, ) use_auth_token: Optional[bool] = field( default=False, metadata={ "help": "Whether to use the authentication token for private datasets." }, ) shard_by_host: Optional[bool] = field( default=False, metadata={ "help": "Whether to shard data files by host in multi-host environments." }, ) blank_caption_prob: Optional[float] = field( default=0.0, metadata={ "help": "Probability of removing some captions for classifier-free guidance." 
}, ) clip_score_column: Optional[str] = field( default="clip_score", metadata={"help": "Column that containts clip score for filtering."}, ) min_clip_score: Optional[float] = field( default=None, metadata={"help": "Minimum clip score required."}, ) max_clip_score: Optional[float] = field( default=None, metadata={"help": "Maximum clip score required."}, ) filter_column: Optional[str] = field( default=None, metadata={"help": "Column that containts classes to be filtered."}, ) filter_value: Optional[str] = field( default=None, metadata={"help": "Class value to be kept during filtering."}, ) multi_eval_ds: Optional[bool] = field( default=False, metadata={ "help": "Whether to look for multiple validation datasets (local support only)." }, ) max_train_samples: Optional[int] = field( default=None, metadata={ "help": "For debugging purposes or quicker training, truncate the number of training examples." }, ) max_eval_samples: Optional[int] = field( default=None, metadata={ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples." }, ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={ "help": "The number of processes to use for the preprocessing. Not used in streaming mode." }, ) overwrite_cache: bool = field( default=False, metadata={ "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode." }, ) # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch seed_dataset: int = field( default=None, metadata={ "help": "Random seed for the dataset that will be set at the beginning of training." }, ) def __post_init__(self): if self.dataset_repo_or_path is None: raise ValueError("Need a dataset repository or path.") @dataclass class TrainingArguments: """ Arguments pertaining to training parameters. """ output_dir: str = field( metadata={ "help": "The output directory where the model predictions and checkpoints will be written." 
}, ) overwrite_output_dir: bool = field( default=False, metadata={ "help": ( "Overwrite the content of the output directory. " "Use this to continue training if output_dir points to a checkpoint directory." ) }, ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field( default=False, metadata={"help": "Whether to run eval on the validation set."} ) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per data parallel device for training."}, ) per_device_eval_batch_size: Optional[int] = field( default=None, metadata={ "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set." }, ) gradient_accumulation_steps: int = field( default=1, metadata={ "help": "Number of updates steps to accumulate before performing an update pass." }, ) gradient_checkpointing: bool = field( default=False, metadata={"help": "Use gradient checkpointing."} ) learning_rate: float = field( default=5e-5, metadata={"help": "The initial learning rate."} ) optim: str = field( default="distributed_shampoo", metadata={ "help": 'The optimizer to use. 
Can be "distributed_shampoo" (default), "adam" or "adafactor"' }, ) weight_decay: float = field( default=0.0, metadata={"help": "Weight decay applied to parameters."} ) beta1: float = field( default=0.9, metadata={"help": "Beta1 for Adam & Distributed Shampoo."}, ) beta2: float = field( default=0.999, metadata={"help": "Beta2 for for Adam & Distributed Shampoo."}, ) adam_epsilon: float = field( default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."} ) max_grad_norm: float = field( default=1.0, metadata={"help": "Max gradient norm for Adafactor."} ) block_size: int = field( default=1024, metadata={"help": "Chunked size for large layers with Distributed Shampoo."}, ) preconditioning_compute_steps: int = field( default=10, metadata={"help": "Number of steps to update preconditioner."} ) skip_preconditioning_dim_size_gt: int = field( default=4096, metadata={"help": "Max size for preconditioning with Distributed Shampoo."}, ) graft_type: str = field( default="rmsprop_normalized", metadata={ "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'" }, ) nesterov: bool = field( default=False, metadata={"help": "Use Nesterov momentum for Distributed Shampoo."}, ) optim_quantized: bool = field( default=False, metadata={ "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)." }, ) shard_shampoo_across: str = field( default="dp", metadata={ "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)." }, ) num_train_epochs: int = field( default=3, metadata={"help": "Total number of training epochs to perform."} ) warmup_steps: int = field( default=0, metadata={"help": "Linear warmup over warmup_steps."} ) lr_decay: str = field( default=None, metadata={ "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential." 
}, ) lr_transition_steps: int = field( default=None, metadata={ "help": "Number of transition steps associated with learning rate decay when using exponential decay." }, ) lr_decay_rate: float = field( default=None, metadata={ "help": "Decay rate associated with learning rate when using exponential decay." }, ) lr_staircase: bool = field( default=False, metadata={ "help": "Whether to use staircase or continuous learning rate when using exponential decay." }, ) lr_offset: int = field( default=0, metadata={"help": "Number of steps to offset learning rate and keep it at 0."}, ) logging_steps: int = field( default=40, metadata={"help": "Log every X updates steps."} ) eval_steps: int = field( default=400, metadata={"help": "Run an evaluation every X steps."} ) save_steps: int = field( default=4000, metadata={"help": "Save checkpoint every X updates steps."} ) log_model: bool = field( default=False, metadata={"help": "Log model to wandb at `save_steps` frequency."}, ) log_norm_steps: int = field( default=True, metadata={"help": "Log parameters and gradients norm at this frequency."}, ) log_histogram_steps: int = field( default=False, metadata={ "help": "Log parameters and gradients histograms at this frequency. Slows down training." }, ) seed_model: int = field( default=42, metadata={ "help": "Random seed for the model that will be set at the beginning of training." 
}, ) embeddings_only: bool = field( default=False, metadata={"help": "Train only embedding layers."} ) init_embeddings: bool = field( default=False, metadata={"help": "When training embedding layers, initialize them."}, ) wandb_entity: Optional[str] = field( default=None, metadata={"help": "The wandb entity to use (for teams)."}, ) wandb_project: str = field( default="dalle-mini", metadata={"help": "The name of the wandb project."}, ) wandb_job_type: str = field( default="Seq2Seq", metadata={"help": "The name of the wandb job type."}, ) assert_TPU_available: bool = field( default=False, metadata={"help": "Verify that TPU is not in use."}, ) use_vmap_trick: bool = field( default=True, metadata={"help": "Verify that TPU is not in use."}, ) mp_devices: Optional[int] = field( default=1, metadata={ "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism." }, ) dp_devices: int = field(init=False) def __post_init__(self): if self.assert_TPU_available: assert ( jax.local_device_count() == 8 ), "TPUs in use, please check running processes" if self.output_dir.startswith("gs://"): assert ( storage is not None ), 'Could not find google.storage. 
Install with "pip install google-cloud-storage"' assert self.optim in [ "distributed_shampoo", "adam", "adafactor", ], f"Selected optimizer not supported: {self.optim}" if self.optim == "adafactor" and self.weight_decay == 0: self.weight_decay = None assert self.graft_type in [ "rmsprop_normalized", "rmsprop", "adagrad", "adagrad_normalized", "sgd", "sqrt_n", ], f"Selected graft type not supported: {self.graft_type}" assert self.lr_decay in [ None, "linear", "exponential", ], f"Selected learning rate decay not supported: {self.lr_decay}" if self.per_device_eval_batch_size is None: self.per_device_eval_batch_size = self.per_device_train_batch_size if self.log_norm_steps is True: self.log_norm_steps = self.logging_steps if not self.do_train: self.num_train_epochs = 1 if ( os.path.exists(self.output_dir) and os.listdir(self.output_dir) and self.do_train and not self.overwrite_output_dir ): raise ValueError( f"Output directory ({self.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome." ) assert self.shard_shampoo_across in [ "dp", "mp", "2d", ], f"Shard shampoo across {self.shard_shampoo_across} not supported." assert ( self.mp_devices > 0 ), f"Number of devices for model parallelism must be > 0" assert ( jax.device_count() % self.mp_devices == 0 ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})." 
self.dp_devices = jax.device_count() // self.mp_devices def split_params(data): """Split params between scanned and non-scanned""" flat = traverse_util.flatten_dict(unfreeze(data)) split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}} for k, v in flat.items(): if "FlaxBartEncoderLayers" in k: split["scanned_encoder"][k] = v elif "FlaxBartDecoderLayers" in k: split["scanned_decoder"][k] = v else: split["standard"][k] = v # remove empty keys split = {k: v for k, v in split.items() if v} for k, v in split.items(): split[k] = freeze(traverse_util.unflatten_dict(v)) return split def unsplit_params(data): flat = {} for k in ["standard", "scanned_encoder", "scanned_decoder"]: if k in data: flat.update(traverse_util.flatten_dict(unfreeze(data[k]))) return freeze(traverse_util.unflatten_dict(flat)) def trainable_params(data, embeddings_only): """Keep only trainable parameters""" if not embeddings_only: return data data = unfreeze(data) trainable = { "lm_head": data["lm_head"], "model": { "decoder": { layer: data["model"]["decoder"][layer] for layer in [ "embed_positions", "embed_tokens", "final_ln", "layernorm_embedding", ] } }, } return freeze(trainable) def init_embeddings(model, params): """Reinitialize trainable embeddings""" # Must match params in trainable_params() above trainable_keypaths = [ "lm_head.kernel", "model.decoder.embed_positions.embedding", "model.decoder.embed_tokens.embedding", "model.decoder.final_ln.bias", "model.decoder.layernorm_embedding.bias", "model.decoder.layernorm_embedding.scale", ] # Note: using private _missing_keys init_keys = {tuple(k.split(".")) for k in trainable_keypaths} model._missing_keys = init_keys return model.init_weights(model.key, model.input_shape, params=params) def main(): # See all possible arguments by passing the --help flag to this script. 
parser = HfArgumentParser( (ModelArguments, DataTrainingArguments, TrainingArguments) ) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file( json_file=os.path.abspath(sys.argv[1]) ) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # check arguments if training_args.mp_devices > jax.local_device_count(): assert ( data_args.seed_dataset is not None ), "Seed dataset must be provided when model is split over multiple hosts" # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Load dataset dataset = Dataset( **asdict(data_args), do_train=training_args.do_train, do_eval=training_args.do_eval, ) logger.info(f"Local TPUs: {jax.local_device_count()}") logger.info(f"Global TPUs: {jax.device_count()}") # Set up wandb run if jax.process_index() == 0: wandb.init( entity=training_args.wandb_entity, project=training_args.wandb_project, job_type=training_args.wandb_job_type, config=parser.parse_args(), ) # Set up our new model config config_args = { k: getattr(model_args, k) for k in ["dropout", "activation_dropout", "attention_dropout"] if getattr(model_args, k) is not None } 
config_args["gradient_checkpointing"] = training_args.gradient_checkpointing if model_args.config_name: config = DalleBartConfig.from_pretrained(model_args.config_name) else: config = None # Load or create new model if model_args.model_name_or_path: model, params = DalleBart.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype), _do_init=False, ) if training_args.embeddings_only and training_args.init_embeddings: params = init_embeddings(model, params) else: model = DalleBart( config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype), _do_init=False, ) params = None for k, v in config_args.items(): setattr(model.config, k, v) params_shape = model.params_shape_tree # get model metadata model_metadata = model_args.get_metadata() # get PartitionSpec for model params (required to be a dict) param_spec = set_partitions(params_shape, model.config.use_scan) params_shape = freeze(params_shape) if params is not None: params = freeze(params) # Load tokenizer tokenizer = DalleBartTokenizer.from_pretrained( model_args.tokenizer_name, use_fast=True ) # Preprocessing the datasets. # We need to normalize and tokenize inputs and targets. 
dataset.preprocess(tokenizer=tokenizer, config=model.config) # Initialize our training dropout_rng = jax.random.PRNGKey(training_args.seed_model) # Store some constant num_epochs = training_args.num_train_epochs # batch size batch_size_per_node_per_grad_step = ( training_args.per_device_train_batch_size * jax.local_device_count() // training_args.mp_devices ) batch_size_per_node = ( batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps ) batch_size_per_step = batch_size_per_node * jax.process_count() eval_batch_size_per_node = ( training_args.per_device_eval_batch_size * jax.local_device_count() // training_args.mp_devices ) eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count() len_train_dataset, len_eval_dataset = dataset.length steps_per_epoch = ( len_train_dataset // batch_size_per_node if len_train_dataset is not None else None ) num_train_steps = ( steps_per_epoch * num_epochs if steps_per_epoch is not None else None ) num_params = model.num_params(params_shape) logger.info("***** Running training *****") logger.info(f" Num examples = {len_train_dataset}") logger.info(f" Num Epochs = {num_epochs}") logger.info( f" Batch size per dp device = {training_args.per_device_train_batch_size}" ) logger.info(f" Number of devices = {jax.device_count()}") logger.info( f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}" ) logger.info(f" Batch size per update = {batch_size_per_step}") logger.info(f" Model parameters = {num_params:,}") # set up wandb run if jax.process_index() == 0: # set default x-axis as 'train/step' wandb.define_metric("*", step_metric="train/step") # add interesting config parameters wandb.config.update( { "len_train_dataset": len_train_dataset, "len_eval_dataset": len_eval_dataset, "batch_size_per_step": batch_size_per_step, "num_params": num_params, "model_config": model.config.to_dict(), "num_devices": jax.device_count(), "versions": { "jax": jax.__version__, "jaxlib": 
jaxlib.__version__, "flax": flax.__version__, "transformers": transformers.__version__, "datasets": datasets.__version__, "wandb": wandb.__version__, "dalle_mini": dalle_mini.__version__, }, } ) # Create learning rate schedule def create_learning_rate_fn() -> Callable[[int], jnp.array]: """Create the learning rate function.""" warmup_fn = optax.linear_schedule( init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps + 1, # ensure not 0 ) last_boundary = training_args.warmup_steps # offset step when resuming if training_args.lr_offset: warmup_fn = optax.join_schedules( schedules=[optax.constant_schedule(0.0), warmup_fn], boundaries=[training_args.lr_offset], ) last_boundary += training_args.lr_offset if training_args.lr_decay is None: return warmup_fn elif training_args.lr_decay == "linear": assert ( num_train_steps is not None ), "linear decay requires knowing the dataset length" decay_fn = optax.linear_schedule( init_value=training_args.learning_rate, end_value=0, transition_steps=num_train_steps - training_args.warmup_steps, ) elif training_args.lr_decay == "exponential": decay_fn = optax.exponential_decay( init_value=training_args.learning_rate, transition_steps=training_args.lr_transition_steps, decay_rate=training_args.lr_decay_rate, staircase=training_args.lr_staircase, ) schedule_fn = optax.join_schedules( schedules=[warmup_fn, decay_fn], boundaries=[last_boundary], ) return schedule_fn learning_rate_fn = create_learning_rate_fn() # create optimizer trainable_params_shape = trainable_params( params_shape, training_args.embeddings_only ) if training_args.optim == "distributed_shampoo": # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729 graft_type = { "sgd": GraftingType.SGD, "adagrad": GraftingType.ADAGRAD, "rmsprop": GraftingType.RMSPROP, "rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED, "sqrt_n": GraftingType.SQRT_N, 
"adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED, }[training_args.graft_type] statistics_partition_spec = ( PartitionSpec(None, training_args.shard_shampoo_across, None) if training_args.shard_shampoo_across != "2d" else PartitionSpec(None, "dp", "mp") ) opt = distributed_shampoo( learning_rate_fn, block_size=training_args.block_size, beta1=training_args.beta1, beta2=training_args.beta2, diagonal_epsilon=1e-10, matrix_epsilon=1e-6, weight_decay=training_args.weight_decay, start_preconditioning_step=max( training_args.preconditioning_compute_steps + 1, 101 ), preconditioning_compute_steps=training_args.preconditioning_compute_steps, statistics_compute_steps=1, best_effort_shape_interpretation=True, graft_type=graft_type, nesterov=training_args.nesterov, exponent_override=0, statistics_partition_spec=statistics_partition_spec, preconditioner_partition_spec=PartitionSpec( training_args.shard_shampoo_across, None, None ) if training_args.shard_shampoo_across != "2d" else PartitionSpec( "mp" if training_args.mp_devices > training_args.dp_devices else "dp", None, None, ), num_devices_for_pjit=training_args.dp_devices, shard_optimizer_states=True, inverse_failure_threshold=0.1, moving_average_for_momentum=True, skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt, clip_by_scaled_gradient_norm=None, precision=jax.lax.Precision.HIGHEST, best_effort_memory_usage_reduction=training_args.optim_quantized, ) # get the real optimizer and helper functions update_fn = opt.update optimizer = {} opt_fn = {} for k, p in split_params(trainable_params_shape).items(): if "scanned" in k: p = jax.eval_shape( lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p ) optimizer[k] = opt.init(p) opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)( optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn ) optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn) elif training_args.optim == "adam": optimizer = optax.adamw( 
learning_rate=learning_rate_fn, b1=training_args.beta1, b2=training_args.beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, ) optimizer = {k: optimizer for k in split_params(trainable_params_shape)} elif training_args.optim == "adafactor": # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( learning_rate=learning_rate_fn, clipping_threshold=training_args.max_grad_norm, weight_decay_rate=training_args.weight_decay, ) optimizer = {k: optimizer for k in split_params(trainable_params_shape)} # get PartitionSpec for optimizer state def get_opt_state_spec_and_shape(): # get opt_state shape without actual init opt_state_shape = {} for k, p in split_params(trainable_params_shape).items(): if "scanned" not in k: opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p) else: opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p) if training_args.optim == "adafactor": # factorized state must be replicated (rank different than params) opt_state_spec = {k: None for k in split_params(trainable_params_shape)} elif training_args.optim in ["adam", "distributed_shampoo"]: def _opt_state_spec_per_leaf(x, spec): if isinstance(x, FrozenDict): # variables with same structure as params return spec else: # other variables such as count return None split_spec = split_params(set_partitions(trainable_params_shape, False)) opt_state_spec = {} for k, p in split_params(trainable_params_shape).items(): if "scanned" in k: p = jax.eval_shape( lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p ) if training_args.optim == "adam": opt_state_spec[k] = jax.tree_util.tree_map( partial(_opt_state_spec_per_leaf, spec=split_spec[k]), opt_state_shape[k], # return None spec for empty elements is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)), ) elif 
training_args.optim == "distributed_shampoo": opt_state_spec[k] = opt_fn[k].pspec_fn( p, split_spec[k], statistics_partition_spec, ) # add dimension for scanned params if "scanned" in k: opt_state_spec[k] = jax.tree_util.tree_map( lambda x: PartitionSpec(*(None,) + x) if x is not None else None, opt_state_spec[k], is_leaf=lambda x: isinstance(x, PartitionSpec), ) else: raise NotImplementedError return freeze(opt_state_spec), freeze(opt_state_shape) opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape() # create a mesh mesh_shape = (training_args.dp_devices, training_args.mp_devices) devices = np.asarray(jax.devices()).reshape(*mesh_shape) mesh = maps.Mesh(devices, ("dp", "mp")) logger.info(f" Mesh shape: {mesh_shape}") # define TrainState class TrainState(struct.PyTreeNode): step: int params: core.FrozenDict[str, Any] opt_state: optax.OptState apply_fn: Callable = struct.field(pytree_node=False) tx: optax.GradientTransformation = struct.field(pytree_node=False) dropout_rng: jnp.ndarray = None epoch: int = 0 train_time: float = 0.0 # total time the model trained train_samples: int = 0 # number of samples seen def apply_gradients(self, *, grads, **kwargs): grads = split_params(trainable_params(grads, training_args.embeddings_only)) params = split_params( trainable_params(self.params, training_args.embeddings_only) ) opt_state = {} # we loop over keys: "standard", "scanned_encoder", "scanned_decoder" for k, param in params.items(): update_fn = self.tx[k].update if "scanned" in k: update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0)) updates, new_opt_state = update_fn(grads[k], self.opt_state[k], param) params[k] = optax.apply_updates(param, updates) opt_state[k] = new_opt_state params = unsplit_params(params) # merge with non-trainable params params, new_params = traverse_util.flatten_dict( unfreeze(self.params) ), traverse_util.flatten_dict(unfreeze(params)) params.update(new_params) params = freeze(traverse_util.unflatten_dict(params)) return 
self.replace( step=self.step + 1, params=params, opt_state=freeze(opt_state), **kwargs, ) @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): opt_state = {} for k, p in split_params( trainable_params(params, training_args.embeddings_only) ).items(): init_fn = tx[k].init if "scanned" in k: init_fn = jax.vmap(init_fn) opt_state[k] = init_fn(p) return cls( step=0, apply_fn=apply_fn, params=params, tx=tx, opt_state=freeze(opt_state), **kwargs, ) # define state spec state_spec = TrainState( params=param_spec, opt_state=opt_state_spec, dropout_rng=None, step=None, epoch=None, train_time=None, train_samples=None, apply_fn=model.__call__, tx=optimizer, ) # init params if not available yet def maybe_init_params(params): if params is not None: # model params are correctly loaded return params else: # params have not been initialized yet return model.init_weights(model.key, model.input_shape) with mesh: logger.info(" Creating state") # restore metadata attr_state = {} keys = ["train_time", "train_samples"] if model_args.restore_state: keys += ["step", "epoch"] attr_state = {k: v for k, v in model_metadata.items() if k in keys} if not model_args.restore_state: def init_state(params): return TrainState.create( apply_fn=model.__call__, tx=optimizer, params=maybe_init_params(params), dropout_rng=dropout_rng, **attr_state, ) state = pjit( init_state, in_axis_resources=(param_spec,) if model_args.model_name_or_path else None, out_axis_resources=state_spec, donate_argnums=(0,), )(params) else: # load opt_state opt_state = from_bytes(opt_state_shape, model_args.get_opt_state()) def restore_state(params, opt_state): return TrainState( apply_fn=model.__call__, tx=optimizer, params=params, opt_state=opt_state, dropout_rng=dropout_rng, **attr_state, ) state = pjit( restore_state, in_axis_resources=( param_spec, opt_state_spec, ), out_axis_resources=state_spec, donate_argnums=(0, 1), )(params, opt_state) # remove opt_state from CPU del opt_state # free CPU memory del params, 
opt_state_spec, opt_state_shape # define batch specs batch_spec = PartitionSpec("dp") grad_batch_spec = PartitionSpec(None, "dp") # define loss def loss_fn(logits, labels): loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) loss = loss.mean() return loss # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens) # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2 use_vmap_trick = training_args.use_vmap_trick # make grad_param_spec for vmap if use_vmap_trick: grad_param_spec = jax.tree_util.tree_map( lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))), param_spec, ) # Define gradient update step fn def train_step(state, batch, train_time): # get a minibatch (one gradient accumulation slice) def get_minibatch(batch, grad_idx): return jax.tree_util.tree_map( lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), batch, ) def compute_loss(params, minibatch, dropout_rng): # minibatch has dim (batch_size, ...) 
minibatch, labels = minibatch.pop("labels") logits = state.apply_fn( **minibatch, params=params, dropout_rng=dropout_rng, train=True )[0] return loss_fn(logits, labels) grad_fn = jax.value_and_grad(compute_loss) def loss_and_grad(grad_idx, dropout_rng): # minibatch at grad_idx for gradient accumulation (None otherwise) minibatch = ( get_minibatch(batch, grad_idx) if grad_idx is not None else batch ) # ensure it is sharded properly minibatch = with_sharding_constraint(minibatch, batch_spec) # only 1 single rng per grad step, let us handle larger batch size (not sure why) dropout_rng, _ = jax.random.split(dropout_rng) if use_vmap_trick: # "vmap trick", calculate loss and grads independently per dp_device loss, grads = jax.vmap( grad_fn, in_axes=(None, 0, None), out_axes=(0, 0) )(state.params, minibatch, dropout_rng) # ensure they are sharded correctly loss = with_sharding_constraint(loss, batch_spec) grads = with_sharding_constraint(grads, grad_param_spec) # average across all devices # Note: we could average per device only after gradient accumulation, right before params update loss, grads = jax.tree_util.tree_map( lambda x: jnp.mean(x, axis=0), (loss, grads) ) else: # "vmap trick" does not work in multi-hosts and requires too much hbm loss, grads = grad_fn(state.params, minibatch, dropout_rng) # ensure grads are sharded grads = with_sharding_constraint(grads, param_spec) # return loss and grads return loss, grads, dropout_rng if training_args.gradient_accumulation_steps == 1: loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng) else: # create initial state for cumul_minibatch_step loop init_minibatch_step = ( 0.0, with_sharding_constraint( jax.tree_util.tree_map(jnp.zeros_like, state.params), param_spec ), state.dropout_rng, ) # accumulate gradients def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout): cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng) cumul_loss, 
cumul_grads = jax.tree_util.tree_map( jnp.add, (cumul_loss, cumul_grads), (loss, grads) ) cumul_grads = with_sharding_constraint(cumul_grads, param_spec) return cumul_loss, cumul_grads, dropout_rng # loop over gradients loss, grads, dropout_rng = jax.lax.fori_loop( 0, training_args.gradient_accumulation_steps, cumul_minibatch_step, init_minibatch_step, ) grads = with_sharding_constraint(grads, param_spec) # sum -> mean loss, grads = jax.tree_util.tree_map( lambda x: x / training_args.gradient_accumulation_steps, (loss, grads) ) grads = with_sharding_constraint(grads, param_spec) # update state state = state.apply_gradients( grads=grads, dropout_rng=dropout_rng, train_time=train_time, train_samples=state.train_samples + batch_size_per_step, ) metrics = { "loss": loss, "learning_rate": learning_rate_fn(state.step), } def maybe_fn(fn, val, zeros, freq): """Call fn only if it is a logging step""" return jax.lax.cond( state.step % freq == 0, fn, lambda _: zeros, val, ) # log additional metrics params = trainable_params(state.params, training_args.embeddings_only) grads = trainable_params(grads, training_args.embeddings_only) if training_args.log_norm_steps: zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float32(0), params) def norm(val): return jax.tree_util.tree_map(lambda x: jnp.linalg.norm(x), val) gradients_norm = maybe_fn( norm, grads, zeros_norm, training_args.log_norm_steps ) params_norm = maybe_fn( norm, params, zeros_norm, training_args.log_norm_steps ) metrics.update( { "gradients_norm": gradients_norm, "params_norm": params_norm, } ) if training_args.log_histogram_steps: zeros_hist = jax.tree_util.tree_map( lambda _: jnp.histogram(jnp.zeros(1), density=True), params ) def histogram(val): return jax.tree_util.tree_map( lambda x: jnp.histogram(x, density=True), val ) gradients_hist = maybe_fn( histogram, grads, zeros_hist, training_args.log_histogram_steps ) params_hist = maybe_fn( histogram, params, zeros_hist, training_args.log_histogram_steps ) 
metrics.update( { "params_hist": params_hist, "gradients_hist": gradients_hist, } ) return state, metrics # Define eval fn eval_model = ( model if model_args.dtype == "float32" else DalleBart( model.config, seed=training_args.seed_model, dtype=jnp.float32, _do_init=False, ) ) def eval_step(state, batch): def compute_eval_loss(batch): batch, labels = batch.pop("labels") logits = eval_model(**batch, params=state.params, train=False)[0] return loss_fn(logits, labels) if use_vmap_trick: loss = jax.vmap(compute_eval_loss)(batch) # ensure they are sharded correctly loss = with_sharding_constraint(loss, batch_spec) # average across all devices loss = jnp.mean(loss) else: loss = compute_eval_loss(batch) return loss # Create parallel version of the train and eval step p_train_step = pjit( train_step, in_axis_resources=( state_spec, grad_batch_spec if training_args.gradient_accumulation_steps > 1 else batch_spec, None, ), out_axis_resources=(state_spec, None), donate_argnums=(0,), ) p_eval_step = pjit( eval_step, in_axis_resources=(state_spec, batch_spec), out_axis_resources=None, ) # define metrics logger class MetricsLogger: def __init__(self, step): # keep state self.state_dict = {} # estimate speed self.step = step self.time = time.perf_counter() self.offset_time = 0.0 def update_state_metrics(self, state): """Update internal state metrics (logged at each call to be used as x-axis)""" self.state_dict = { f'train/{k.split("_")[-1]}': state[k] for k in ["step", "epoch", "train_time", "train_samples"] } # timing metrics new_step = int(state["step"]) new_time = time.perf_counter() if new_step > self.step: # remove time for eval & save delta_time = new_time - self.time - self.offset_time self.offset_time = 0 time_per_step = delta_time / (new_step - self.step) self.step = new_step self.time = new_time self.log_time("train_per_step", time_per_step, offset=False) self.log_time("train_per_log", delta_time, offset=False) def log_time(self, key, duration, offset=True): if 
jax.process_index() == 0: wandb.log({f"time/{key}": duration, **self.state_dict}) if offset: self.offset_time += duration def log(self, metrics, prefix=None): if jax.process_index() == 0: log_metrics = {} for k, v in metrics.items(): if "_norm" in k: if self.step % training_args.log_norm_steps == 0: log_metrics[f"{k}/"] = unfreeze(v) elif "_hist" in k: if self.step % training_args.log_histogram_steps == 0: v = jax.tree_util.tree_map( lambda x: jax.device_get(x), unfreeze(v) ) v = jax.tree_util.tree_map( lambda x: wandb.Histogram(np_histogram=x), v, is_leaf=lambda x: isinstance(x, tuple), ) log_metrics[f"{k}/"] = v else: if prefix is not None: k = f"{prefix}/{k}" log_metrics[k] = v wandb.log({**log_metrics, **self.state_dict}) # keep local copy of state local_state = { k: jax.device_get(getattr(state, k)).item() for k in ["step", "epoch", "train_time", "train_samples"] } # init variables start_time = time.perf_counter() - local_state["train_time"] train_metrics = None evaluation_ran = False save_model_ran = False metrics_logger = MetricsLogger(local_state["step"]) epochs = tqdm( range(local_state["epoch"], num_epochs), desc=f"Epoch ... 
(1/{num_epochs})", position=0, disable=jax.process_index() > 0, ) def run_evaluation(): # ======================== Evaluating ============================== if training_args.do_eval: start_eval_time = time.perf_counter() # get validation datasets val_datasets = list( dataset.other_eval_datasets.keys() if hasattr(dataset, "other_eval_datasets") else [] ) val_datasets += ["eval"] for val_dataset in val_datasets: eval_loader = dataset.dataloader( val_dataset, eval_batch_size_per_step * max(1, training_args.mp_devices // jax.local_device_count()), ) eval_steps = ( len_eval_dataset // eval_batch_size_per_step if len_eval_dataset is not None else None ) eval_loss = [] for batch in tqdm( eval_loader, desc="Evaluating...", position=2, leave=False, total=eval_steps, disable=jax.process_index() > 0, ): # need to keep only eval_batch_size_per_node items relevant to the node batch = jax.tree_util.tree_map( lambda x: x.reshape( (jax.process_count(), eval_batch_size_per_node) + x.shape[1:] ), batch, ) batch = jax.tree_util.tree_map( lambda x: x[jax.process_index()], batch ) # add dp dimension when using "vmap trick" if use_vmap_trick: bs_shape = ( jax.local_device_count() // training_args.mp_devices, training_args.per_device_eval_batch_size, ) batch = jax.tree_util.tree_map( lambda x: x.reshape(bs_shape + x.shape[1:]), batch ) # freeze batch to pass safely to jax transforms batch = freeze(batch) # accumulate losses async eval_loss.append(p_eval_step(state, batch)) # get the mean of the loss eval_loss = jnp.stack(eval_loss) eval_loss = jnp.mean(eval_loss) eval_metrics = {"loss": eval_loss} # log metrics metrics_logger.log(eval_metrics, prefix=val_dataset) # Print metrics and update progress bar desc = f"Epoch... 
({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})" epochs.write(desc) epochs.desc = desc # log time metrics_logger.log_time("eval", time.perf_counter() - start_eval_time) return eval_metrics def run_save_model(state, eval_metrics=None): if jax.process_index() == 0: start_save_time = time.perf_counter() output_dir = training_args.output_dir use_bucket = output_dir.startswith("gs://") if use_bucket: bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}" bucket, dir_path = str(bucket_path).split("/", 1) tmp_dir = tempfile.TemporaryDirectory() output_dir = tmp_dir.name # save model params = jax.device_get(state.params) model.save_pretrained( output_dir, params=params, ) # save tokenizer tokenizer.save_pretrained(output_dir) # copy to bucket if use_bucket: client = storage.Client() bucket = client.bucket(bucket) for filename in Path(output_dir).glob("*"): blob_name = str(Path(dir_path) / "model" / filename.name) blob = bucket.blob(blob_name) blob.upload_from_filename(str(filename)) tmp_dir.cleanup() # save state opt_state = jax.device_get(state.opt_state) if use_bucket: blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack") blob = bucket.blob(blob_name) blob.upload_from_file(io.BytesIO(to_bytes(opt_state))) else: with (Path(output_dir) / "opt_state.msgpack").open("wb") as f: f.write(to_bytes(opt_state)) # save to W&B if training_args.log_model: # save some space c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache() c.cleanup(wandb.util.from_human_size("20GB")) metadata = { k: jax.device_get(getattr(state, k)).item() for k in ["step", "epoch", "train_time", "train_samples"] } metadata["num_params"] = num_params if eval_metrics is not None: metadata["eval"] = eval_metrics # create model artifact if use_bucket: metadata["bucket_path"] = f"gs://{bucket_path}/model" artifact = wandb.Artifact( name=f"model-{wandb.run.id}", type="DalleBart_model", metadata=metadata, ) if use_bucket: 
artifact.add_reference(metadata["bucket_path"]) else: for filename in [ "config.json", "flax_model.msgpack", "merges.txt", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json", ]: artifact.add_file( f"{Path(training_args.output_dir) / filename}" ) wandb.run.log_artifact(artifact) # create state artifact if use_bucket: metadata["bucket_path"] = f"gs://{bucket_path}/state" artifact_state = wandb.Artifact( name=f"state-{wandb.run.id}", type="DalleBart_state", metadata=metadata, ) if use_bucket: artifact_state.add_reference(metadata["bucket_path"]) else: artifact_state.add_file( f"{Path(training_args.output_dir) / 'opt_state.msgpack'}" ) wandb.run.log_artifact(artifact_state) metrics_logger.log_time("save_model", time.perf_counter() - start_save_time) logger.info(" Ready to start training") with mesh: for epoch in epochs: state = state.replace(epoch=epoch) local_state["epoch"] = epoch # ======================== Training ================================ metrics_logger.update_state_metrics(local_state) metrics_logger.log({}) if training_args.do_train: # load data - may be replicated on multiple nodes node_groups = max( 1, training_args.mp_devices // jax.local_device_count() ) loader_bs = batch_size_per_node * node_groups train_loader = dataset.dataloader( "train", loader_bs, epoch, ) # train for batch in tqdm( train_loader, desc="Training...", position=1, leave=False, total=steps_per_epoch, disable=jax.process_index() > 0, ): # calculate delta time (we have a lag of one step but it's ok) train_time = time.perf_counter() - start_time # reset control variables evaluation_ran = False save_model_ran = False # set correct shape to batch # - add grad_step dim if gradient_accumulation_steps > 1 bs_shape = ( (batch_size_per_node_per_grad_step * node_groups,) if not use_vmap_trick else ( jax.local_device_count() * node_groups // training_args.mp_devices, # local dp devices training_args.per_device_train_batch_size, ) ) if 
training_args.gradient_accumulation_steps > 1: # reshape data into (gradient_accumulation_steps, batch_per_node, ...) # to avoid any data redistribution when sharding bs_shape = ( training_args.gradient_accumulation_steps, ) + bs_shape # reshape batch batch = jax.tree_util.tree_map( lambda x: x.reshape(bs_shape + x.shape[1:]), batch, ) # freeze batch to pass safely to jax transforms batch = freeze(batch) # train step state, train_metrics = p_train_step(state, batch, train_time) local_state["step"] += 1 local_state["train_time"] = train_time local_state["train_samples"] += batch_size_per_step if ( local_state["step"] % training_args.logging_steps == 0 and jax.process_index() == 0 ): metrics_logger.update_state_metrics(local_state) metrics_logger.log(train_metrics, prefix="train") eval_metrics = None if local_state["step"] % training_args.eval_steps == 0: eval_metrics = run_evaluation() evaluation_ran = True if local_state["step"] % training_args.save_steps == 0: run_save_model(state, eval_metrics) save_model_ran = True # log final train metrics if train_metrics is not None: metrics_logger.update_state_metrics(local_state) metrics_logger.log(train_metrics, prefix="train") epochs.write( f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})" ) # Final evaluation at the end of each epoch if not evaluation_ran: eval_metrics = run_evaluation() # save checkpoint after each epoch if not save_model_ran: run_save_model(state, eval_metrics) if __name__ == "__main__": main()