Repository: borisdayma/dalle-mini
Branch: main
Commit: f0be4de61028
Files: 48
Total size: 392.9 KB

Directory structure:
dalle-mini/

├── .gitattributes
├── .github/
│   ├── FUNDING.yml
│   └── workflows/
│       ├── check_size.yml
│       ├── pypi_release.yml
│       ├── style.yml
│       ├── sync_to_hub.yml.backup
│       └── sync_to_hub_debug.yml
├── .gitignore
├── CITATION.cff
├── Docker/
│   ├── Dockerfile
│   ├── README.md
│   └── build_docker.sh
├── LICENSE
├── Makefile
├── README.md
├── app/
│   ├── gradio/
│   │   ├── app.py
│   │   └── backend.py
│   └── streamlit/
│       ├── app.py
│       └── backend.py
├── pyproject.toml
├── run_docker_image.sh
├── setup.cfg
├── setup.py
├── src/
│   └── dalle_mini/
│       ├── __init__.py
│       ├── data.py
│       └── model/
│           ├── __init__.py
│           ├── configuration.py
│           ├── modeling.py
│           ├── partitions.py
│           ├── processor.py
│           ├── text.py
│           ├── tokenizer.py
│           └── utils.py
└── tools/
    ├── dataset/
    │   └── encode_dataset.ipynb
    ├── inference/
    │   ├── inference_pipeline.ipynb
    │   └── run_infer_notebook.sh
    └── train/
        ├── config/
        │   ├── mega/
        │   │   └── config.json
        │   ├── micro/
        │   │   └── config.json
        │   ├── mini/
        │   │   └── config.json
        │   └── mini_glu/
        │       └── config.json
        ├── embeddings_retrain_preparation.ipynb
        ├── scalable_shampoo/
        │   ├── README.md
        │   ├── distributed_shampoo.py
        │   ├── quantization_utils.py
        │   ├── sm3.py
        │   └── symmetric_matrices/
        │       └── symmetric_matrices.py
        ├── sweep.yaml
        └── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text


================================================
FILE: .github/FUNDING.yml
================================================
github: [borisdayma]


================================================
FILE: .github/workflows/check_size.yml
================================================
name: Check file size

on:
  pull_request:
    branches: [main]

  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - name: Check large files
        uses: ActionsDesk/lfs-warning@v2.0
        with:
          filesizelimit: 10485760 # = 10MB, so we can sync to HF spaces


================================================
FILE: .github/workflows/pypi_release.yml
================================================
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: "3.x"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
      - name: Build package
        run: python -m build
      - name: Publish package
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}


================================================
FILE: .github/workflows/style.yml
================================================
name: Lint

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: psf/black@stable
      - uses: actions/setup-python@v2
        with:
          python-version: 3.9
      - name: Install requirements
        run: pip install ".[dev]"
      - uses: jamescurtin/isort-action@master


================================================
FILE: .github/workflows/sync_to_hub.yml.backup
================================================
name: Sync to Hugging Face hub - Obsolete to avoid app disruptions

on:
  push:
    branches: [main]

  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      - name: Push to hub
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: git push https://boris:$HF_TOKEN@huggingface.co/spaces/dalle-mini/dalle-mini main


================================================
FILE: .github/workflows/sync_to_hub_debug.yml
================================================
name: Deploy to debug app

on:
  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub-debug:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      - name: Push to hub
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: git push --force https://boris:$HF_TOKEN@huggingface.co/spaces/dalle-mini/dalle-mini-debug +HEAD:main


================================================
FILE: .gitignore
================================================
__pycache__
.ipynb_checkpoints
.streamlit
wandb/
*.egg-info/
jax_cache/


================================================
FILE: CITATION.cff
================================================
# YAML 1.2
---
abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
authors: 
  -
    family-names: Dayma
    given-names: Boris
  -
    family-names: Patil
    given-names: Suraj
  -
    family-names: Cuenca
    given-names: Pedro
  -
    family-names: Saifullah
    given-names: Khalid
  -
    family-names: Abraham
    given-names: Tanishq
  -
    family-names: "Lê Khắc"
    given-names: "Phúc"
  -
    family-names: Melas
    given-names: Luke
  -
    family-names: Ghosh
    given-names: Ritobrata
cff-version: "1.1.0"
date-released: 2021-07-29
identifiers: 
keywords: 
  - dalle
  - "text-to-image generation"
  - transformer
  - "zero-shot"
  - JAX
license: "Apache-2.0"
doi: 10.5281/zenodo.5146400
message: "If you use this project, please cite it using these metadata."
repository-code: "https://github.com/borisdayma/dalle-mini"
title: "DALL·E Mini"
version: "v0.1-alpha"
...

================================================
FILE: Docker/Dockerfile
================================================
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04

RUN apt-get update && apt-get install -y \
  git \
  python3 \
  python3-pip \
  && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
  && pip install -q \
  git+https://github.com/borisdayma/dalle-mini.git \
  git+https://github.com/patil-suraj/vqgan-jax.git

RUN pip install jupyter

WORKDIR /workspace



================================================
FILE: Docker/README.md
================================================
# Running Dalle-mini With Docker

This folder contains the Dockerfile needed to build a Docker image that can easily run Dalle-mini.

## Inference

Steps to run inference with Dalle-mini are as follows:

1. Build the Docker image with `dalle-mini/Docker/build_docker.sh`.
2. Run the container with `dalle-mini/run_docker_image.sh`.
3. Navigate to `/workspace/tools/inference/` and run `run_infer_notebook.sh`.
4. Click the Jupyter Notebook link and run through the notebook.

### Inference Video Tutorial

Alternatively, check out the video tutorials on how to run Dalle-mini on [Linux](https://www.youtube.com/watch?v=eWpzLIa6v9E) and [Windows](https://www.youtube.com/watch?v=OqEuEe-xSKk).


================================================
FILE: Docker/build_docker.sh
================================================
docker build . -t dalle-mini:latest


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 The DALL·E mini Authors

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: Makefile
================================================
.PHONY: style

style:
	black .
	isort .

================================================
FILE: README.md
================================================
# DALL·E Mini

<a href="https://www.craiyon.com/"><img src="https://www.craiyon.com/thumbnail.png" width="300"></a>

## How to use it?

You can use the model on [🖍️ craiyon](https://www.craiyon.com/)

## How does it work?

Refer to our reports:

* [DALL·E mini - Generate Images from Any Text Prompt](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini-Generate-images-from-any-text-prompt--VmlldzoyMDE4NDAy)
* [DALL·E mini - Explained](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA)
* [DALL·E mega - Training Journal](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2)

## Development

### Dependencies Installation

For inference only, use `pip install dalle-mini`.

For development, clone the repo and use `pip install -e ".[dev]"`.
Before making a PR, check style with `make style`.

You can experiment with the pipeline step by step through our [`inference pipeline notebook`](tools/inference/inference_pipeline.ipynb).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb)

### Training of DALL·E mini

Use [`tools/train/train.py`](tools/train/train.py).

You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.

## FAQ

### Where to find the latest models?

Trained models are on 🤗 Model Hub:

* [VQGAN-f16-16384](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384) for encoding/decoding images
* [DALL·E mini](https://huggingface.co/dalle-mini/dalle-mini) or [DALL·E mega](https://huggingface.co/dalle-mini/dalle-mega) for generating images from a text prompt
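
The VQGAN name encodes how compact the image representation is. As a rough worked example, assuming the usual taming-transformers naming convention (`f16` = spatial downsampling factor 16, `16384` = codebook size) and a 256×256 input resolution (an assumption, not stated above), the number of image tokens the seq2seq model has to generate can be computed as follows. `vqgan_token_budget` is a hypothetical helper for illustration only, not part of `dalle_mini`.

```python
import math


def vqgan_token_budget(image_size: int, downsample: int, codebook_size: int) -> dict:
    """Hypothetical helper: latent grid size and bit budget of a VQGAN encoding.

    Assumes a square image whose side is divisible by the downsampling factor.
    """
    if image_size % downsample != 0:
        raise ValueError("image_size must be divisible by the downsampling factor")
    grid = image_size // downsample             # codes per side of the latent grid
    num_tokens = grid * grid                    # discrete codes per image
    bits_per_token = math.log2(codebook_size)   # information carried by one code index
    return {
        "grid": grid,
        "num_tokens": num_tokens,
        "bits_per_token": bits_per_token,
        "total_bits": num_tokens * bits_per_token,
    }


# Assumed 256x256 input with the f16 / 16384 settings from the model name.
budget = vqgan_token_budget(image_size=256, downsample=16, codebook_size=16384)
print(budget)
```

Under these assumptions each image becomes a 16×16 grid of 256 tokens at 14 bits each (3,584 bits), versus 256 × 256 × 24 = 1,572,864 bits for the raw 24-bit RGB image.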

### Where does the logo come from?

The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone for us.

## Contributing

Join the community on the [LAION Discord](https://discord.gg/xBPBXfcFHd).
Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model with cool prompts!

You can also use these great projects from the community:

* spin off your own app with [DALL-E Playground repository](https://github.com/saharmor/dalle-playground) (thanks [Sahar](https://twitter.com/theaievangelist))

* try [DALL·E Flow](https://github.com/jina-ai/dalle-flow) project for generating, diffusion, and upscaling in a Human-in-the-Loop workflow (thanks [Han Xiao](https://github.com/hanxiao))

  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jina-ai/dalle-flow/blob/main/client.ipynb)

* run on [Replicate](https://replicate.com/borisdayma/dalle-mini), in the browser or via API

## Acknowledgements

* 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
* Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
* [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management

## Authors & Contributors

DALL·E mini was initially developed by:

* [Boris Dayma](https://github.com/borisdayma)
* [Suraj Patil](https://github.com/patil-suraj)
* [Pedro Cuenca](https://github.com/pcuenca)
* [Khalid Saifullah](https://github.com/khalidsaifullaah)
* [Tanishq Abraham](https://github.com/tmabraham)
* [Phúc Lê Khắc](https://github.com/lkhphuc)
* [Luke Melas](https://github.com/lukemelas)
* [Ritobrata Ghosh](https://github.com/ghosh-r)

Many thanks to the people who helped make it better:

* the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
* [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer and always giving great suggestions
* [Phil Wang](https://github.com/lucidrains) has provided a lot of cool implementations of transformer variants and gives interesting insights with [x-transformers](https://github.com/lucidrains/x-transformers)
* [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
* the [Gradio team](https://gradio.app/) made an amazing UI for our app

## Citing DALL·E mini

If you find DALL·E mini useful in your research or wish to refer to it, please use the following BibTeX entry.

```text
@misc{Dayma_DALL·E_Mini_2021,
      author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
      doi = {10.5281/zenodo.5146400},
      month = {7},
      title = {DALL·E Mini},
      url = {https://github.com/borisdayma/dalle-mini},
      year = {2021}
}
```

## References

Original DALL·E from "[Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092)" with image reranking by CLIP from "[Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)".

Image encoder from "[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841v2)".

Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461v1)" with implementation of a few variants:

* "[GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)"
* "[DeepNet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
* "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
* "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
* "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
* "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
* "[Sinkformers: Transformers with Doubly Stochastic Attention](https://arxiv.org/abs/2110.11773)"
* "[Foundation Transformers](https://arxiv.org/abs/2210.06423)"
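
Several of these variants only change the normalization layer. As an illustration, Root Mean Square Layer Normalization rescales each feature vector by its root mean square instead of centering it by the mean as LayerNorm does. The pure-Python sketch below follows the formula from that paper, `y_i = g_i * x_i / sqrt(mean(x^2) + eps)`; the function name, the `eps` default and the inputs are illustrative assumptions, not the repository's Flax implementation.

```python
import math
from typing import List, Optional


def rms_norm(x: List[float], gain: Optional[List[float]] = None, eps: float = 1e-6) -> List[float]:
    """Root Mean Square Layer Normalization over a single feature vector.

    Unlike LayerNorm, no mean is subtracted and no bias is added: the vector is
    divided by its root mean square and scaled by a learned per-feature gain.
    """
    if gain is None:
        gain = [1.0] * len(x)  # identity gain, as at initialization
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


x = [3.0, -4.0, 0.0, 5.0]
y = rms_norm(x)
print([round(v, 4) for v in y])
```

With the identity gain, the output vector has a root mean square of (almost exactly) 1; skipping the mean-centering step is what makes it cheaper than LayerNorm.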

Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".

### Citations

```text
@misc{ramesh2021zeroshot,
  title={Zero-Shot Text-to-Image Generation}, 
  author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
  year={2021},
  eprint={2102.12092},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

```text
@misc{radford2021learning,
  title={Learning Transferable Visual Models From Natural Language Supervision}, 
  author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
  year={2021},
  eprint={2103.00020},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

```text
@misc{esser2021taming,
  title={Taming Transformers for High-Resolution Image Synthesis}, 
  author={Patrick Esser and Robin Rombach and Björn Ommer},
  year={2021},
  eprint={2012.09841},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

```text
@misc{lewis2019bart,
  title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension}, 
  author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
  year={2019},
  eprint={1910.13461},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```

```text
@misc{anil2021scalable,
  title={Scalable Second Order Optimization for Deep Learning},
  author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
  year={2021},
  eprint={2002.09018},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

```text
@misc{shazeer2020glu,
  title={GLU Variants Improve Transformer},
  author={Noam Shazeer},
  year={2020},
  url={https://arxiv.org/abs/2002.05202}    
}
```

```text
@misc{wang2022deepnet,
  title={DeepNet: Scaling transformers to 1,000 layers},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
  year={2022},
  eprint={2203.00555},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```

```text
@misc{shleifer2021normformer,
  title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
  author={Sam Shleifer and Jason Weston and Myle Ott},
  year={2021},
  eprint={2110.09456},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```

```text
@inproceedings{liu2022swinv2,
  title={Swin Transformer V2: Scaling Up Capacity and Resolution}, 
  author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}
```

```text
@misc{ding2021cogview,
  title = {CogView: Mastering Text-to-Image Generation via Transformers},
  author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
  year = {2021},
  eprint = {2105.13290},
  archivePrefix = {arXiv},
  primaryClass = {cs.CV}
}
```

```text
@misc{zhang2019rmsnorm,
  title = {Root Mean Square Layer Normalization},
  author = {Biao Zhang and Rico Sennrich},
  year = {2019},
  eprint = {1910.07467},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG}
}
```

```text
@misc{sander2021sinkformers,
  title = {Sinkformers: Transformers with Doubly Stochastic Attention},
  url = {https://arxiv.org/abs/2110.11773},
  author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
  publisher = {arXiv},
  year = {2021},
}
```

```text
@misc{shamir2020smooth,
  title = {Smooth activations and reproducibility in deep networks},
  url = {https://arxiv.org/abs/2010.09931},
  author = {Shamir, Gil I. and Lin, Dong and Coviello, Lorenzo},
  publisher = {arXiv},
  year = {2020},
}
```

```text
@misc{wang2022foundation,
  title = {Foundation Transformers},
  url = {https://arxiv.org/abs/2210.06423},
  author = {Wang, Hongyu and Ma, Shuming and Huang, Shaohan and Dong, Li and Wang, Wenhui and Peng, Zhiliang and Wu, Yu and Bajaj, Payal and Singhal, Saksham and Benhaim, Alon and Patra, Barun and Liu, Zhun and Chaudhary, Vishrav and Song, Xia and Wei, Furu},
  publisher = {arXiv},
  year = {2022},
}
```


================================================
FILE: app/gradio/app.py
================================================
#!/usr/bin/env python
# coding: utf-8
import os

import gradio as gr
from backend import get_images_from_backend

block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
backend_url = os.environ["BACKEND_SERVER"] + "/generate"


def infer(prompt):
    response = get_images_from_backend(prompt, backend_url)
    return response["images"]


with block:
    gr.Markdown("<h1><center>DALL·E mini</center></h1>")
    gr.Markdown(
        "DALL·E mini is an AI model that generates images from any prompt you give!"
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    margin=False,
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        gallery = gr.Gallery(label="Generated images", show_label=False).style(
            grid=[3], height="auto"
        )
        text.submit(infer, inputs=text, outputs=gallery)
        btn.click(infer, inputs=text, outputs=gallery)

    gr.Markdown(
        """___
   <p style='text-align: center'>
   Created by <a href="https://twitter.com/borisdayma" target="_blank">Boris Dayma</a> et al. 2021-2022
   <br/>
   <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini-Generate-images-from-any-text-prompt--VmlldzoyMDE4NDAy" target="_blank">Project Report</a>
   </p>"""
    )


block.launch(enable_queue=False)


================================================
FILE: app/gradio/backend.py
================================================
# Client requests to Dalle-Mini Backend server

import base64
from io import BytesIO

import requests
from PIL import Image


class ServiceError(Exception):
    def __init__(self, status_code):
        self.status_code = status_code


def get_images_from_backend(prompt, backend_url):
    r = requests.post(backend_url, json={"prompt": prompt})
    if r.status_code == 200:
        json = r.json()
        images = json["images"]
        images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
        version = json.get("version", "unknown")
        return {"images": images, "version": version}
    else:
        raise ServiceError(r.status_code)


def get_model_version(url):
    r = requests.get(url)
    if r.status_code == 200:
        version = r.json()["version"]
        return version
    else:
        raise ServiceError(r.status_code)


================================================
FILE: app/streamlit/app.py
================================================
#!/usr/bin/env python
# coding: utf-8

import streamlit as st
from backend import ServiceError, get_images_from_backend

st.sidebar.markdown(
    """
<style>
.aligncenter {
    text-align: center;
}
</style>
<p class="aligncenter">
    <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
</p>
""",
    unsafe_allow_html=True,
)
st.sidebar.markdown(
    """
___
<p style='text-align: center'>
DALL·E mini is an AI model that generates images from any prompt you give!
</p>

<p style='text-align: center'>
Created by Boris Dayma et al. 2021-2022
<br/>
<a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
</p>
        """,
    unsafe_allow_html=True,
)

st.header("DALL·E mini")
st.subheader("Generate images from text")

prompt = st.text_input("What do you want to see?")

DEBUG = False
if prompt != "":
    container = st.empty()
    container.markdown(
        f"""
        <style> p {{ margin:0 }} div {{ margin:0 }} </style>
        <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
        <div class="stAlert">
        <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
        <div class="st-b7">
        <div class="css-whx05o e13vu3m50">
        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
                <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
                Generating predictions for: <b>{prompt}</b>
        </div>
        </div>
        </div>
        </div>
        </div>
        </div>
        <small><i>Predictions may take up to 5 minutes under high load. Please stand by.</i></small>
    """,
        unsafe_allow_html=True,
    )

    try:
        backend_url = st.secrets["BACKEND_SERVER"] + "/generate"
        response = get_images_from_backend(prompt, backend_url)
        selected = response["images"]
        version = response["version"]

        margin = 0.1  # for better position of zoom in arrow
        n_columns = 3
        cols = st.columns([1] + [margin, 1] * (n_columns - 1))
        for i, img in enumerate(selected):
            cols[(i % n_columns) * 2].image(img)
        container.markdown(f"**{prompt}**")

        # st.sidebar.markdown(
        #    f"<small><center>{version}</center></small>", unsafe_allow_html=True
        # )

        # st.markdown(
        #    f"""
        # These results have been obtained using model `{version}` from [an ongoing training run](https://wandb.ai/dalle-mini/dalle-mini/runs/mheh9e55).
        # """
        # )

        st.button("Again!", key="again_button")

    except ServiceError as error:
        container.text(f"Service unavailable, status: {error.status_code}")
    except KeyError:
        if DEBUG:
            container.markdown(
                """
            **Error: BACKEND_SERVER unset**

            Please create a file called `.streamlit/secrets.toml` inside the app's folder and add a line that configures the server URL:
            ```
            BACKEND_SERVER="<server url>"
            ```
            """
            )
        else:
            container.markdown(
                "Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
            )


================================================
FILE: app/streamlit/backend.py
================================================
# Client requests to Dalle-Mini Backend server

import base64
from io import BytesIO

import requests
from PIL import Image


class ServiceError(Exception):
    def __init__(self, status_code):
        self.status_code = status_code


def get_images_from_backend(prompt, backend_url):
    r = requests.post(backend_url, json={"prompt": prompt})
    if r.status_code == 200:
        json = r.json()
        images = json["images"]
        images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
        version = json.get("version", "unknown")
        return {"images": images, "version": version}
    else:
        raise ServiceError(r.status_code)


def get_model_version(url):
    r = requests.get(url)
    if r.status_code == 200:
        version = r.json()["version"]
        return version
    else:
        raise ServiceError(r.status_code)


================================================
FILE: pyproject.toml
================================================
[tool.isort]
profile = "black"

================================================
FILE: run_docker_image.sh
================================================
#!/bin/bash

# Runs the dalle-mini Docker image. Change or remove the `--gpus` flag if you don't have nvidia-docker or the required GPUs.
docker run --rm --name dallemini -it -p 8888:8888  --gpus all  -v "${PWD}":/workspace dalle-mini:latest


================================================
FILE: setup.cfg
================================================
[metadata]
name = dalle-mini
version = attr: dalle_mini.__version__
author = Boris Dayma et al.
author_email = boris.dayma@gmail.com
description = DALL·E mini - Generate images from a text prompt
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/borisdayma/dalle-mini
project_urls =
    Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
classifiers =
    Programming Language :: Python :: 3
    License :: OSI Approved :: Apache Software License
    Operating System :: OS Independent
    Topic :: Scientific/Engineering :: Artificial Intelligence
    Development Status :: 3 - Alpha
    Intended Audience :: Developers

[options]
package_dir =
    =src
packages = find:
python_requires = >=3.6
install_requires =
    transformers==4.25.1
    einops
    unidecode
    ftfy
    emoji
    pillow
    jax==0.3.25
    flax==0.6.3
    orbax==0.0.23
    wandb

[options.extras_require]
dev =
    tqdm
    optax
    braceexpand
    datasets[streaming]
    black[jupyter]
    isort

[options.packages.find]
where = src


================================================
FILE: setup.py
================================================
from setuptools import setup

if __name__ == "__main__":
    setup()


================================================
FILE: src/dalle_mini/__init__.py
================================================
__version__ = "0.1.5"

from .model import DalleBart, DalleBartProcessor


================================================
FILE: src/dalle_mini/data.py
================================================
import random
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from braceexpand import braceexpand
from datasets import Dataset, load_dataset

from .model.text import TextNormalizer


@dataclass
class Dataset:
    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    shard_by_host: bool = False
    blank_caption_prob: float = 0.0
    clip_score_column: str = "clip_score"
    min_clip_score: float = None
    max_clip_score: float = None
    filter_column: str = None
    filter_value: str = None
    multi_eval_ds: bool = False
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    other_eval_datasets: list = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        if self.seed_dataset is None:
            # create a random seed
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # set numpy rng
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1
        # feed blank captions only in streaming mode for now
        # otherwise dataset could be cached with same blanked captions
        if self.blank_caption_prob:
            assert (
                self.streaming is True
            ), "blank_caption_prob can only be used in streaming mode"
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            # accept braceexpand notation
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # for list of files, split training data shards by host
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # multiple validation datasets
        if self.multi_eval_ds:
            assert Path(
                self.dataset_repo_or_path
            ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
            data_files = {
                split.name: [str(f) for f in split.glob("*.parquet")]
                for split in Path(self.dataset_repo_or_path).glob("*")
            }
            # rename "valid" to "validation" if present for consistency
            if "valid" in data_files:
                data_files["validation"] = data_files["valid"]
                del data_files["valid"]
            self.dataset_repo_or_path = "parquet"

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )
            # other eval datasets
            other_eval_splits = dataset.keys() - {"train", "validation"}
            self.other_eval_datasets = {
                split: dataset[split] for split in other_eval_splits
            }

    def preprocess(self, tokenizer, config):
        # get required config variables
        decoder_start_token_id = config.decoder_start_token_id
        normalize_text = config.normalize_text
        max_length = config.max_text_length

        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # filter data
        partial_filter_function = partial(
            filter_function,
            filter_column=self.filter_column,
            filter_value=self.filter_value,
            clip_score_column=self.clip_score_column,
            min_clip_score=self.min_clip_score,
            max_clip_score=self.max_clip_score,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).filter(partial_filter_function)
                        if self.streaming
                        else getattr(self, ds).filter(
                            partial_filter_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Filtering datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.filter(partial_filter_function)
                    if self.streaming
                    else ds.filter(
                        partial_filter_function,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Filtering datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )
            if hasattr(self, "other_eval_datasets"):
                self.other_eval_datasets = {
                    split: (
                        ds.map(partial_normalize_function)
                        if self.streaming
                        else ds.map(
                            partial_normalize_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Normalizing datasets",
                        )
                    )
                    for split, ds in self.other_eval_datasets.items()
                }

        # blank captions
        if self.blank_caption_prob:
            partial_blank_caption_function = partial(
                blank_caption_function,
                text_column=self.text_column,
                blank_caption_prob=self.blank_caption_prob,
                rng=self.np_rng,
            )
            if hasattr(self, "train_dataset"):
                self.train_dataset = (
                    self.train_dataset.map(partial_blank_caption_function)
                    if self.streaming
                    else self.train_dataset.map(
                        partial_blank_caption_function,
                        num_proc=None
                        if self.seed_dataset
                        else self.preprocessing_num_workers,
                        load_from_cache_file=False,
                        desc="Blanking some captions",
                    )
                )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=[
                                self.text_column,
                                self.encoding_column,
                            ],
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )
        if hasattr(self, "other_eval_datasets"):
            self.other_eval_datasets = {
                split: (
                    ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=[
                            self.text_column,
                            self.encoding_column,
                        ],
                    )
                    if self.streaming
                    else ds.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=getattr(ds, "column_names"),
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                )
                for split, ds in self.other_eval_datasets.items()
            }

    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Yields batches of size `batch_size` from `dataset`, dropping the final incomplete batch.
            Examples are shuffled if `rng` is set.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(
            dataset: Dataset,
            epoch: int,
        ):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True  # single pass, except for training data on multiple hosts
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host, we run forever (no epoch) as hosts need to stop
                # at the same time and training data may not be split equally
                # For validation data we put the entire batch on each host and then
                # keep only the one specific to each host (could be improved but not necessary)
                if epoch is not None:
                    assert split == "train"
                    # reshuffle training data at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            ds = self.other_eval_datasets[split]

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            # only training data is shuffled; evaluation data is read in order
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int):
    """
    Shift input ids one token to the right and prepend `decoder_start_token_id`.
    """
    # keep the integer dtype of the token ids (np.zeros would default to float64)
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids


def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
    if (
        blank_caption_prob
        and (rng.random() if rng is not None else np.random.random())
        < blank_caption_prob
    ):
        example[text_column] = ""
    return example


def normalize_function(example, text_column, text_normalizer):
    example[text_column] = text_normalizer(example[text_column])
    return example


def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    if min_clip_score is not None and example[clip_score_column] < min_clip_score:
        return False
    if max_clip_score is not None and example[clip_score_column] > max_clip_score:
        return False
    if filter_column is not None and example[filter_column] != filter_value:
        return False
    return True


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    inputs = examples[text_column]
    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    model_inputs = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # set up targets
    # Note: labels correspond to our target indices
    # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
    labels = examples[encoding_column]
    labels = np.asarray(labels)

    # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
    model_inputs["labels"] = labels

    # In our case, this prepends the bos token and removes the last one
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs


================================================
FILE: src/dalle_mini/model/__init__.py
================================================
from .configuration import DalleBartConfig
from .modeling import DalleBart
from .partitions import set_partitions
from .processor import DalleBartProcessor
from .tokenizer import DalleBartTokenizer


================================================
FILE: src/dalle_mini/model/configuration.py
================================================
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DalleBart model configuration """
import warnings

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from .utils import PretrainedFromWandbMixin

logger = logging.get_logger(__name__)


class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
    model_type = "dallebart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        normalize_text=False,
        encoder_vocab_size=50264,
        image_vocab_size=16384,  # encoded image token space
        image_length=256,  # number of encoded tokens
        max_text_length=64,  # max number of text tokens
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        scale_embedding=False,
        gradient_checkpointing=True,
        use_scan=None,
        use_cache=True,
        is_encoder_decoder=True,
        forced_eos_token_id=None,
        tie_word_embeddings=False,  # different modalities and sizes
        do_sample=True,
        # transformer variants
        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
        ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln), "subln"
        use_head_scale=False,  # used in NormFormer
        use_cosine_attention=False,  # used in Swin v2
        tau_init=0.05,  # used only in cosine attention (Swin v2)
        use_absolute_position_embeddings=True,  # default
        use_swin_position_embeddings=False,  # used in Swin v1/v2
        use_deepnet_scaling=False,  # used in Deepnet
        use_subln_init=False,
        use_glu=True,  # "GLU Variants Improve Transformer"
        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
        sinkhorn_iters=1,  # used in SinkFormers
        use_final_ln_encoder=True,  # final layer normalization in encoder
        use_final_ln_decoder=True,  # final layer normalization in decoder
        # parameters that should not be necessary but could affect results
        force_ln_scale=False,  # force scale in layernorm even when followed by dense layers
        **kwargs,
    ):
        # text normalizer
        self.normalize_text = normalize_text

        # transformer variants
        self.use_bias = use_bias
        assert ln_type in [
            "rmsnorm",
            "layernorm",
        ], "ln_type must be 'rmsnorm' or 'layernorm'"
        self.ln_type = ln_type
        if ln_positions == "deepnet":
            ln_positions = "postln"
        assert ln_positions in [
            "normformer",
            "swinv2",
            "cogview",
            "postln",
            "preln",
            "subln",
        ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln', 'subln'"
        self.use_head_scale = use_head_scale
        assert use_alibi is False, "use_alibi is not supported yet"
        self.ln_positions = ln_positions
        self.use_cosine_attention = use_cosine_attention
        self.tau_init = tau_init
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.use_swin_position_embeddings = use_swin_position_embeddings
        self.use_deepnet_scaling = use_deepnet_scaling
        self.use_subln_init = use_subln_init
        self.use_glu = use_glu
        self.use_alibi = use_alibi
        self.sinkhorn_iters = sinkhorn_iters
        if ln_positions == "postln":
            assert (
                use_final_ln_encoder
            ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
            assert (
                use_final_ln_decoder
            ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
        self.use_final_ln_encoder = use_final_ln_encoder
        self.use_final_ln_decoder = use_final_ln_decoder
        self.force_ln_scale = force_ln_scale

        # common parameters
        self.encoder_vocab_size = encoder_vocab_size
        self.image_vocab_size = image_vocab_size
        self.image_length = image_length
        self.max_text_length = max_text_length
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing
        # all layers are the same in most configurations
        self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
        assert not (
            self.use_scan and ln_positions == "swinv2"
        ), "scan cannot be used with 'swinv2'"
        self.scale_embedding = (
            scale_embedding  # scale factor will be sqrt(d_model) if True
        )

        # special token ids default to image_vocab_size if not provided,
        # i.e. a single extra id appended after the image vocabulary
        decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
        bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
        pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
        eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)

        # we generate to image_length + 1 (for bos) by default
        min_length = kwargs.pop("min_length", image_length + 1)
        max_length = kwargs.pop("max_length", image_length + 1)

        super().__init__(
            # args required in parent class
            is_encoder_decoder=is_encoder_decoder,
            tie_word_embeddings=tie_word_embeddings,
            forced_eos_token_id=forced_eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            min_length=min_length,
            max_length=max_length,
            do_sample=do_sample,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get(
            "force_bos_token_to_be_generated", False
        ):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )
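

# Illustrative sketch (hypothetical helper, not called anywhere in the
# package): it reproduces how DalleBartConfig resolves its special-token and
# generation-length defaults. Any special token id that is not passed
# explicitly falls back to `image_vocab_size`, i.e. the first id past the
# image codebook, and generation defaults to `image_length + 1` tokens so the
# BOS token fits in front of the image tokens. Example sizes are arbitrary.
def _example_resolve_generation_defaults(image_vocab_size, image_length, **kwargs):
    """Return the token ids and lengths the config would pick (sketch only)."""
    resolved = {}
    for name in (
        "decoder_start_token_id",
        "bos_token_id",
        "pad_token_id",
        "eos_token_id",
    ):
        # mirrors `kwargs.pop(name, image_vocab_size)` in __init__
        resolved[name] = kwargs.get(name, image_vocab_size)
    for name in ("min_length", "max_length"):
        # mirrors `kwargs.pop(name, image_length + 1)` in __init__
        resolved[name] = kwargs.get(name, image_length + 1)
    return resolved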


================================================
FILE: src/dalle_mini/model/modeling.py
================================================
# coding=utf-8
# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DalleBart model. """

import math
from functools import partial
from typing import Any, Dict, Optional, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from flax.core.frozen_dict import unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.linear import PrecisionLike
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import custom_jvp, lax
from jax.random import PRNGKey
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
)
from transformers.modeling_flax_utils import ACT2FN
from transformers.models.bart.modeling_flax_bart import (
    FlaxBartAttention,
    FlaxBartForConditionalGeneration,
    FlaxBartForConditionalGenerationModule,
    FlaxBartModule,
)
from transformers.utils import ModelOutput, logging

from .configuration import DalleBartConfig
from .utils import PretrainedFromWandbMixin

logger = logging.get_logger(__name__)

remat = nn_partitioning.remat


def smelu(beta: Any = 1.0):
    """
    Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
    https://arxiv.org/abs/2202.06499
    """

    @custom_jvp
    @jax.jit
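    def _smelu(x: Any) -> Any:
        # 0 for x <= -beta, x for x >= beta, (x + beta)^2 / (4 * beta) in between
        quad = jnp.square(x + beta) / (4 * beta)
        return jnp.where(x <= -beta, 0.0, jnp.where(x >= beta, x, quad))

    # tangent = g * derivative: 0, 1, or (x + beta) / (2 * beta) on each piece
    _smelu.defjvps(
        lambda g, ans, x: g
        * jnp.where(
            x <= -beta, 0.0, jnp.where(x >= beta, 1.0, (x + beta) / (2 * beta))
        )
    )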
    return _smelu
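

# Reference sketch (hypothetical helper, NumPy only, not used by the model):
# SmeLU is 0 for x <= -beta, x for x >= beta, and the quadratic
# (x + beta)**2 / (4 * beta) in between; its derivative is 0, 1 and
# (x + beta) / (2 * beta) on those pieces. The pieces meet with matching
# value and slope at x = +/-beta, which is what makes the activation smooth.
# Handy for checking `_smelu` and its custom JVP numerically.
def _example_smelu_reference(x, beta=1.0):
    """Return (value, derivative) of SmeLU evaluated elementwise on x."""
    import numpy as np

    x = np.asarray(x, dtype=np.float64)
    value = np.where(
        x <= -beta,
        0.0,
        np.where(x >= beta, x, np.square(x + beta) / (4 * beta)),
    )
    grad = np.where(
        x <= -beta,
        0.0,
        np.where(x >= beta, 1.0, (x + beta) / (2 * beta)),
    )
    return value, grad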


ACT2FN.update({"smelu": smelu()})


# deepnet initialization
def deepnet_init(init_std, gain=1):
    init = jax.nn.initializers.normal(init_std)

    def _init(*args, **kwargs):
        return gain * init(*args, **kwargs)

    return _init
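

# Illustrative sketch (hypothetical helper, NumPy only): `deepnet_init`
# multiplies samples from normal(init_std) by `gain`, so the resulting kernel
# entries follow a zero-mean normal with standard deviation gain * init_std.
# NumPy's generator stands in for the JAX initializer here.
def _example_gain_scaled_normal(shape, init_std, gain=1.0, seed=0):
    """Draw gain * N(0, init_std**2) samples, mirroring deepnet_init (sketch)."""
    import numpy as np

    rng = np.random.default_rng(seed)
    return gain * rng.normal(loc=0.0, scale=init_std, size=shape)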


# deepnet gain
deepnet_gain = {
    "encoder": {
        "alpha": lambda config: 0.81
        * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
        "beta": lambda config: 0.87
        * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
    },
    "decoder": {
        "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
        "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
    },
}
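

# Worked example (hypothetical helper, plain Python): the DeepNet constants
# above for an N-layer encoder and M-layer decoder. `alpha` up-weights the
# residual branch (residual * alpha + f(x)) and `beta` scales the
# initialization of the value/output projections and feed-forward kernels.
# For the encoder the exponents cancel, so alpha * beta = 0.81 * 0.87; for the
# decoder alpha * beta = (3 / 12) ** 0.25.
def _example_deepnet_constants(encoder_layers, decoder_layers):
    """Evaluate the `deepnet_gain` formulas for given layer counts (sketch)."""
    n, m = encoder_layers, decoder_layers
    return {
        "encoder": {
            "alpha": 0.81 * (n**4 * m) ** 0.0625,
            "beta": 0.87 * (n**4 * m) ** -0.0625,
        },
        "decoder": {
            "alpha": (3 * m) ** 0.25,
            "beta": (12 * m) ** -0.25,
        },
    }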

# subln gain
subln_gain = {
    "encoder": lambda config: math.sqrt(
        1.0
        / 3.0
        * math.log(3 * config.decoder_layers)
        * math.log(2 * config.encoder_layers)
    ),
    "decoder": lambda config: math.sqrt(math.log(3 * config.decoder_layers)),
}
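

# Worked example (hypothetical helper, plain Python): the Sub-LN gains above.
# With N encoder and M decoder layers the encoder gain is
# sqrt(ln(3M) * ln(2N) / 3) and the decoder gain is sqrt(ln(3M)). They scale
# the same kernels as the DeepNet `beta` when `use_subln_init=True` and DeepNet
# scaling is off (cross-attention projections excluded).
def _example_subln_gains(encoder_layers, decoder_layers):
    """Evaluate the `subln_gain` formulas for given layer counts (sketch)."""
    import math

    n, m = encoder_layers, decoder_layers
    return {
        "encoder": math.sqrt(math.log(3 * m) * math.log(2 * n) / 3.0),
        "decoder": math.sqrt(math.log(3 * m)),
    }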


class RMSNorm(nn.Module):
    """
    From "Root Mean Square Layer Normalization" (https://arxiv.org/abs/1910.07467).

    Adapted from flax.linen.LayerNorm
    """

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32
    use_scale: bool = True
    scale_init: Any = jax.nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        reduction_axes = (-1,)
        feature_axes = (-1,)

        rms_sq = self._compute_rms_sq(x, reduction_axes)

        return self._normalize(
            self,
            x,
            rms_sq,
            reduction_axes,
            feature_axes,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_scale,
            self.scale_init,
        )

    def _compute_rms_sq(self, x, axes):
        x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
        rms_sq = jnp.mean(jax.lax.square(x), axes)
        return rms_sq

    def _normalize(
        self,
        mdl,
        x,
        rms_sq,
        reduction_axes,
        feature_axes,
        dtype,
        param_dtype,
        epsilon,
        use_scale,
        scale_init,
    ):
        reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
        feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
        stats_shape = list(x.shape)
        for axis in reduction_axes:
            stats_shape[axis] = 1
        rms_sq = rms_sq.reshape(stats_shape)
        feature_shape = [1] * x.ndim
        reduced_feature_shape = []
        for ax in feature_axes:
            feature_shape[ax] = x.shape[ax]
            reduced_feature_shape.append(x.shape[ax])
        mul = lax.rsqrt(rms_sq + epsilon)
        if use_scale:
            scale = mdl.param(
                "scale", scale_init, reduced_feature_shape, param_dtype
            ).reshape(feature_shape)
            mul *= scale
        y = mul * x
        return jnp.asarray(y, dtype)
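

# Reference sketch (hypothetical helper, NumPy only): what RMSNorm computes
# over the last axis. Unlike LayerNorm it neither subtracts the mean nor adds
# a bias: y = x / sqrt(mean(x**2) + epsilon) * scale.
def _example_rms_norm(x, scale=None, epsilon=1e-6):
    """NumPy version of RMSNorm over the last axis (sketch)."""
    import numpy as np

    x = np.asarray(x, dtype=np.float64)
    rms_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    y = x / np.sqrt(rms_sq + epsilon)
    if scale is not None:
        # learned per-feature scale, broadcast over leading axes
        y = y * np.asarray(scale, dtype=np.float64)
    return y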


def norm(type, *args, **kwargs):
    if type == "rmsnorm":
        return RMSNorm(*args, **kwargs)
    elif type == "layernorm":
        return nn.LayerNorm(*args, **kwargs)
    else:
        raise ValueError(f"Unknown norm type {type}")
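

# Illustrative sketch (hypothetical helper, NumPy only): the two `ln_type`
# options differ only in mean-centering. Ignoring the learned scale and bias,
# LayerNorm computes (x - mean) / sqrt(var + eps) while RMSNorm computes
# x / sqrt(mean(x**2) + eps), so the two coincide on zero-mean inputs.
def _example_layernorm_and_rmsnorm(x, epsilon=1e-6):
    """Return (layernorm(x), rmsnorm(x)) over the last axis, no affine terms."""
    import numpy as np

    x = np.asarray(x, dtype=np.float64)
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    layernorm = (x - mean) / np.sqrt(var + epsilon)
    rmsnorm = x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + epsilon)
    return layernorm, rmsnorm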


def dot_product_attention_weights(
    query: Any,
    key: Any,
    bias: Optional[Any] = None,
    mask: Optional[Any] = None,
    embed_pos: Optional[Any] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Any = jnp.float32,
    precision: PrecisionLike = None,
    sinkhorn_iters: int = 1,
    is_encoder: bool = False,
    tau=None,
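):
    """
    Computes dot-product attention weights given query and key.
    The mask is expected to be already folded into `bias`; `mask` itself is
    only used to re-apply the masking between Sinkhorn iterations.

    Adapted from flax.linen.attention.dot_product_attention_weights
    """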
    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # attn weight shape is (batch..., num_heads, q_length, kv_length)
    attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)

    # divide by tau (used in Swin v2)
    if tau is not None:
        attn_weights = attn_weights / tau
    else:
        depth = query.shape[-1]
        attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)

    # apply attention bias: masking, proximity bias, etc.
    if bias is not None:
        attn_weights = attn_weights + bias

    # add relative position
    if embed_pos is not None:
        attn_weights = attn_weights + embed_pos

    # normalize the attention weights
    if not is_encoder or sinkhorn_iters == 1:
        # sinkhorn does not work for causal (leaks info of future tokens into past)
        attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
    else:
        # adapted from https://github.com/lucidrains/sinkhorn-transformer
        for i in range(sinkhorn_iters):
            # masked positions (e.g. padding) have been set to -inf through bias
            if i % 2 == 0:
                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
            else:
                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
            if mask is not None:
                attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
        attn_weights = jnp.exp(attn_weights).astype(dtype)

    # apply attention dropout
    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        if broadcast_dropout:
            # dropout is broadcast across the batch + head dimensions
            dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
            keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
        else:
            keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
            keep_prob, dtype=dtype
        )
        attn_weights = attn_weights * multiplier

    return attn_weights
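

# Reference sketch (hypothetical helper, NumPy only): the Sinkhorn branch above
# alternates log-domain normalization over the key axis (even steps) and the
# query axis (odd steps). A single step is exactly a row softmax, which is why
# `sinkhorn_iters == 1` can use `jax.nn.softmax` directly; more steps push a
# square matrix towards being doubly stochastic. Masking is omitted here.
def _example_sinkhorn_weights(logits, iters):
    """Alternate row/column log-normalization of `logits` (sketch)."""
    import numpy as np

    def logsumexp(a, axis):
        # numerically stable log(sum(exp(a))) along `axis`
        m = np.max(a, axis=axis, keepdims=True)
        return m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))

    a = np.asarray(logits, dtype=np.float64)
    for i in range(iters):
        a = a - logsumexp(a, axis=-1 if i % 2 == 0 else -2)
    return np.exp(a)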


class FlaxBartAttention(FlaxBartAttention):
    """
    Edits:
    - causal mask is used only in decoder and considers image_length
    - scale attention heads per NormFormer paper
    """

    is_encoder: bool = False
    is_cross_attention: bool = False
    q_length: int = None
    k_length: int = None

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
        )

        if self.config.use_deepnet_scaling:
            gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
                self.config
            )
        elif self.config.use_subln_init and not self.is_cross_attention:
            gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)

        self.q_proj = dense(
            kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.k_proj = dense(
            kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.v_proj = dense(
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (
                self.config.use_deepnet_scaling
                or (self.config.use_subln_init and not self.is_cross_attention)
            )
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.out_proj = dense(
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (
                self.config.use_deepnet_scaling
                or (self.config.use_subln_init and not self.is_cross_attention)
            )
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.config.use_head_scale:
            self.head_scale = self.param(
                "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
            )

        if self.config.use_cosine_attention:
            # TODO: try using a learnt scale, somehow it immediately diverges in my experiments
            self.tau = self.config.tau_init

        if self.config.use_swin_position_embeddings:
            self.rel_bias = nn.Embed(
                self.q_length,
                self.k_length * self.num_heads,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )

        if self.causal:
            # used only in decoder
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
            )

        if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
            self.mid_layernorm = norm(
                self.config.ln_type, dtype=self.dtype, epsilon=1e-05
            )

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache and prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(
                causal_mask, (batch_size,) + causal_mask.shape[1:]
            )

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.config.use_cosine_attention:
            # normalize q and k
            query_states = query_states / (
                jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
            )
            key_states = key_states / (
                jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
            )

        # relative position embeddings
        if self.config.use_swin_position_embeddings:
            position_ids = jnp.arange(self.q_length)
            embed_pos = self.rel_bias(position_ids)
            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
        else:
            embed_pos = None

        tau = self.tau if self.config.use_cosine_attention else None
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            mask=attention_mask,
            embed_pos=embed_pos,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
            sinkhorn_iters=self.config.sinkhorn_iters,
            is_encoder=self.is_encoder,
            tau=tau,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        if self.config.use_head_scale:
            # per Normformer
            attn_output = attn_output * self.head_scale
        attn_output = self._merge_heads(attn_output)

        if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
            attn_output = self.mid_layernorm(attn_output)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
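

# Reference sketch (hypothetical helper, NumPy only, single head, no batch):
# how the attention above turns boolean masks into an additive 0 / -inf bias
# before the softmax. `key_padding_mask` is 1 for real tokens and 0 for
# padding; `causal=True` lets query i attend only to keys j <= i. Passing
# `tau` switches to cosine attention: q and k are unit-normalized and the
# logits are divided by tau instead of sqrt(head_dim).
def _example_masked_attention_weights(
    query, key, key_padding_mask=None, causal=False, tau=None
):
    """Return softmax attention weights of shape (q_len, k_len) (sketch)."""
    import numpy as np

    q = np.asarray(query, dtype=np.float64)
    k = np.asarray(key, dtype=np.float64)
    if tau is not None:
        q = q / (np.linalg.norm(q, axis=-1, keepdims=True) + 1e-8)
        k = k / (np.linalg.norm(k, axis=-1, keepdims=True) + 1e-8)
        logits = q @ k.T / tau
    else:
        logits = q @ k.T / np.sqrt(q.shape[-1])

    mask = np.ones(logits.shape, dtype=bool)
    if key_padding_mask is not None:
        mask &= np.asarray(key_padding_mask, dtype=bool)[None, :]
    if causal:
        # lower-triangular: True where key index <= query index
        mask &= np.tri(*logits.shape, dtype=bool)
    logits = logits + np.where(mask, 0.0, -np.inf)

    # stable softmax over keys; masked entries become exactly 0
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)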


class GLU(nn.Module):
    """From "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202)"""

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        if self.config.use_deepnet_scaling:
            gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
                self.config
            )
        elif self.config.use_subln_init:
            gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)

        if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        w = nn.Dense(
            self.ffn_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (self.config.use_deepnet_scaling or self.config.use_subln_init)
            else jax.nn.initializers.normal(self.config.init_std),
        )(x)
        w = ACT2FN[self.config.activation_function](w)
        v = nn.Dense(
            self.ffn_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (self.config.use_deepnet_scaling or self.config.use_subln_init)
            else jax.nn.initializers.normal(self.config.init_std),
        )(x)
        x = w * v
        if self.config.ln_positions in ["normformer", "subln"]:
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        x = nn.Dropout(rate=self.config.activation_dropout)(
            x, deterministic=deterministic
        )

        x = nn.Dense(
            self.embed_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (self.config.use_deepnet_scaling or self.config.use_subln_init)
            else jax.nn.initializers.normal(self.config.init_std),
        )(x)
        if self.config.ln_positions in ["swinv2", "cogview"]:
            x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
        return x
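

# Reference sketch (hypothetical helper, NumPy only): the GLU block above
# replaces the single up-projection of a plain FFN with two parallel ones,
# gating one by the activation of the other before projecting back down:
# out = (act(x @ w) * (x @ v)) @ w_out. Norms, dropout and biases are omitted,
# and `act` stands in for ACT2FN[config.activation_function].
def _example_glu(x, w, v, w_out, act):
    """Gated feed-forward: (act(x @ w) * (x @ v)) @ w_out (sketch)."""
    import numpy as np

    x = np.asarray(x, dtype=np.float64)
    gate = act(x @ np.asarray(w, dtype=np.float64))
    value = x @ np.asarray(v, dtype=np.float64)
    return (gate * value) @ np.asarray(w_out, dtype=np.float64)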


class FFN(nn.Module):
    """Simple FFN layer"""

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        if self.config.use_deepnet_scaling:
            gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
                self.config
            )
        elif self.config.use_subln_init:
            gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
        if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        x = nn.Dense(
            self.ffn_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (self.config.use_deepnet_scaling or self.config.use_subln_init)
            else jax.nn.initializers.normal(self.config.init_std),
        )(x)
        x = ACT2FN[self.config.activation_function](x)
        if self.config.ln_positions in ["normformer", "subln"]:
            x = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(x)
        x = nn.Dropout(rate=self.config.activation_dropout)(
            x, deterministic=deterministic
        )
        x = nn.Dense(
            self.embed_dim,
            dtype=self.dtype,
            use_bias=self.config.use_bias,
            kernel_init=deepnet_init(self.config.init_std, gain)
            if (self.config.use_deepnet_scaling or self.config.use_subln_init)
            else jax.nn.initializers.normal(self.config.init_std),
        )(x)
        if self.config.ln_positions in ["swinv2", "cogview"]:
            x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
        return x
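

# Summary sketch (hypothetical helper, plain Python): where GLU and FFN above
# insert a norm for each `ln_positions` value. "pre" is before the first
# projection, "mid" is after the activation (or gating) and "post" is after
# the down-projection. "postln" adds no norm inside these blocks; its norms
# are applied by the encoder/decoder layers around the residual connection.
def _example_ffn_norm_sites(ln_positions):
    """Return the norm sites used inside GLU/FFN for `ln_positions` (sketch)."""
    sites = []
    if ln_positions in ["normformer", "cogview", "preln", "subln"]:
        sites.append("pre")
    if ln_positions in ["normformer", "subln"]:
        sites.append("mid")
    if ln_positions in ["swinv2", "cogview"]:
        sites.append("post")
    return sites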


class FlaxBartEncoderLayer(nn.Module):
    """
    Edits:
    - no bias
    - use custom FlaxBartAttention
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    add_norm: bool = False
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        if self.config.use_scan:
            hidden_states = hidden_states[0]

        res_gain = (
            deepnet_gain["encoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states
        if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=True,
            is_cross_attention=False,
            q_length=self.config.max_text_length,
            k_length=self.config.max_text_length,
        )(hidden_states=hidden_states, attention_mask=attention_mask)

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if self.config.use_scan:
            outputs = (outputs, None)

        return outputs
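

# Illustrative sketch (hypothetical helper, plain Python): with
# `config.use_scan`, the layer above receives its hidden states wrapped in a
# tuple (the scan carry), unwraps them with `hidden_states[0]`, and returns
# `(outputs, None)`, i.e. a new carry plus no per-layer output. The loop below
# imitates that (carry) -> (carry, y) contract, assuming output_attentions is
# off so the carry keeps the same structure from layer to layer; the real
# scan call is not part of this excerpt.
def _example_scan_over_layers(layer_fn, hidden_states, num_layers):
    """Apply `layer_fn` num_layers times, threading a tuple carry (sketch)."""
    carry = (hidden_states,)
    per_layer_outputs = []
    for _ in range(num_layers):
        carry, y = layer_fn(carry)
        per_layer_outputs.append(y)
    return carry, per_layer_outputs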


class FlaxBartDecoderLayer(nn.Module):
    """
    Edits:
    - no bias
    - use custom FlaxBartAttention
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    add_norm: bool = False
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        if self.config.use_scan:
            hidden_states = hidden_states[0]

        res_gain = (
            deepnet_gain["decoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states

        # Self Attention
        if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=False,
            is_cross_attention=False,
            q_length=self.config.image_length,
            k_length=self.config.image_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # Cross Attention
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            if self.config.ln_positions in ["normformer", "cogview", "preln"]:
                hidden_states = norm(
                    self.config.ln_type,
                    dtype=self.dtype,
                    epsilon=1e-05,
                    use_scale=self.config.force_ln_scale,
                )(hidden_states)
            hidden_states, cross_attn_weights = FlaxBartAttention(
                config=self.config,
                embed_dim=embed_dim,
                num_heads=self.config.decoder_attention_heads,
                dropout=self.config.attention_dropout,
                bias=self.config.use_bias,
                dtype=self.dtype,
                is_encoder=False,
                is_cross_attention=True,
                q_length=self.config.image_length,
                k_length=self.config.max_text_length,
            )(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)
            hidden_states = nn.Dropout(rate=self.config.dropout)(
                hidden_states, deterministic=deterministic
            )
            hidden_states = residual * res_gain + hidden_states
            if self.config.ln_positions in ["postln"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)

        # Feed forward
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights, cross_attn_weights)

        if self.config.use_scan:
            outputs = (outputs, None)

        return outputs


class FlaxBartEncoderLayerCollection(nn.Module):
    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Edits:
    - use custom FlaxBartEncoderLayer
    - allow Gradient Checkpointing (nn.remat)
    """

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        n_layers = self.config.encoder_layers
        layer = (
            remat(
                FlaxBartEncoderLayer,
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartEncoderLayer
        )

        if self.config.use_scan:
            # all blocks are the same so we use nn.scan
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            hidden_states = (hidden_states,)
            # we use a scale on all norms (even last layer) to allow scanning
            hidden_states, _ = nn.scan(
                layer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=n_layers,
            )(
                self.config,
                dtype=self.dtype,
                add_norm=self.config.ln_positions == "postln",
                name="FlaxBartEncoderLayers",
            )(
                hidden_states,
                attention_mask,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]
        else:
            for i in range(n_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # postln: add a layernorm after every layer
                # swinv2: add a layernorm after every 6th layer, except the last one
                add_norm = self.config.ln_positions == "postln" or (
                    self.config.ln_positions == "swinv2"
                    and ((i + 1) % 6 == 0)
                    and (i != n_layers - 1)
                )
                # we don't need to scale the norm for the last layer
                use_scale = i != n_layers - 1
                layer_outputs = layer(
                    self.config,
                    dtype=self.dtype,
                    add_norm=add_norm,
                    use_scale=use_scale,
                    name=f"FlaxBartEncoderLayer_{i}",
                )(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # add hidden states from the last layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

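
# Illustrative sketch (hypothetical helper, not part of the original dalle-mini
# code): the per-layer norm schedule used by the unscanned loops in
# FlaxBartEncoderLayerCollection and FlaxBartDecoderLayerCollection. With
# "postln" every layer adds a norm; with "swinv2" an extra norm is added after
# every 6th layer except the last one; the norm scale is dropped only on the
# last layer. Assumes the same ln_positions strings as DalleBartConfig; the
# scanned path (use_scan=True) does not use this schedule.
def _sketch_layer_norm_schedule(n_layers, ln_positions):
    schedule = []
    for i in range(n_layers):
        add_norm = ln_positions == "postln" or (
            ln_positions == "swinv2" and (i + 1) % 6 == 0 and i != n_layers - 1
        )
        use_scale = i != n_layers - 1
        schedule.append((add_norm, use_scale))
    return schedule


# Example: a 12-layer "swinv2" stack adds a single intermediate norm, after layer 5.
# [i for i, (add, _) in enumerate(_sketch_layer_norm_schedule(12, "swinv2")) if add]
# -> [5]
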

class FlaxBartDecoderLayerCollection(nn.Module):
    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Edits:
    - use custom FlaxBartDecoderLayer
    - allow Gradient Checkpointing (nn.remat)
    """

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and encoder_hidden_states is not None) else None
        )

        n_layers = self.config.decoder_layers
        layer = (
            remat(
                FlaxBartDecoderLayer,
                static_argnums=(4, 5, 6),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartDecoderLayer
        )

        if self.config.use_scan:
            # all blocks are the same so we use nn.scan
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            hidden_states = (hidden_states,)
            # we use a scale on all norms (even last layer) to allow scanning
            hidden_states, _ = nn.scan(
                layer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                    nn.broadcast,
                ),
                length=n_layers,
            )(
                self.config,
                dtype=self.dtype,
                add_norm=self.config.ln_positions == "postln",
                name="FlaxBartDecoderLayers",
            )(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]

        else:
            for i in range(n_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
                # postln: add a layernorm after every layer
                # swinv2: add a layernorm after every 6th layer, except the last one
                add_norm = self.config.ln_positions == "postln" or (
                    self.config.ln_positions == "swinv2"
                    and ((i + 1) % 6 == 0)
                    and (i != n_layers - 1)
                )
                # we don't need to scale the norm for the last layer
                use_scale = i != n_layers - 1
                layer_outputs = layer(
                    self.config,
                    dtype=self.dtype,
                    add_norm=add_norm,
                    use_scale=use_scale,
                    name=f"FlaxBartDecoderLayer_{i}",
                )(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    init_cache,
                    output_attentions,
                    deterministic,
                )

                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

            # add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
            all_cross_attentions,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


class FlaxBartEncoder(nn.Module):
    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Edits:
    - offset set to 0 (no padding token)
    - use max_text_length instead of max_position_embeddings
    - use custom FlaxBartEncoderLayerCollection
    - embed_tokens cannot be None (issue at compile time)
    """

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # BART offsets position ids by 2 when padding_idx is specified and enlarges
        # num_embeddings accordingly; we use no padding token here, so the offset is 0
        self.offset = 0
        if self.config.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                self.config.max_text_length + self.offset,  # one position per text token
                embed_dim,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(
            self.config.ln_type, dtype=self.dtype, epsilon=1e-05
        )

        # postln is already applied in every layer
        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
            self.final_ln = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )
        else:
            self.final_ln = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        hidden_states = self.embed_tokens(input_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            embed_pos = self.embed_positions(position_ids + self.offset)
            hidden_states = hidden_states + embed_pos

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.final_ln is None:
            final_output = outputs[0]
        else:
            final_output = self.final_ln(outputs[0])

        if not return_dict:
            return (final_output,) + outputs[1:]

        return FlaxBaseModelOutput(
            last_hidden_state=final_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxBartDecoder(nn.Module):
    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    """
    Edits:
    - offset set to 0 (no padding token)
    - use image_length instead of max_position_embeddings
    - use custom FlaxBartDecoderLayerCollection
    - embed_tokens cannot be None (issue at compile time)
    """

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.embed_scale = (
            math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
        )

        # BART offsets position ids by 2 when padding_idx is specified and enlarges
        # num_embeddings accordingly; we use no padding token here, so the offset is 0
        self.offset = 0
        if self.config.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                self.config.image_length + self.offset,  # image length for BOS
                embed_dim,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )

        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(
            self.config.ln_type, dtype=self.dtype, epsilon=1e-05
        )

        # postln is already applied in every layer
        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
            self.final_ln = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )
        else:
            self.final_ln = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        hidden_states = self.embed_tokens(input_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            embed_pos = self.embed_positions(position_ids + self.offset)
            hidden_states = hidden_states + embed_pos

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.final_ln is None:
            final_output = outputs[0]
        else:
            final_output = self.final_ln(outputs[0])

        if not return_dict:
            return (final_output,) + outputs[1:]

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=final_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


class FlaxBartModule(FlaxBartModule):
    """
    Edits
    - use custom FlaxBartEncoder & FlaxBartDecoder
    - use separate embeddings for Encoder & Decoder
    """

    def setup(self):
        encoder_embed_tokens = nn.Embed(
            self.config.encoder_vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        decoder_embed_tokens = nn.Embed(
            self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.encoder = FlaxBartEncoder(
            self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens
        )
        self.decoder = FlaxBartDecoder(
            self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens
        )


class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """
    Edits:
    - no bias
    - lm_head set to image_vocab_size + 1 (for BOS)
    - uses custom FlaxBartModule
    """

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.image_vocab_size
            + 1,  # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_embedding.T}}, hidden_states
            )
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

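
# Illustrative sketch (hypothetical helper, not part of the original code): when
# tie_word_embeddings is set, the code above reuses a shared embedding matrix E of
# shape (vocab_size, d_model), read from params["shared"]["embedding"], as the
# bias-free lm_head kernel via lm_head.apply({"params": {"kernel": E.T}}, h),
# which amounts to logits = h @ E.T. Plain Python lists stand in for jnp arrays.
def _sketch_tied_logits(hidden_states, shared_embedding):
    # hidden_states: (seq_len, d_model); shared_embedding: (vocab_size, d_model)
    # returns logits of shape (seq_len, vocab_size)
    return [
        [sum(h * e for h, e in zip(h_row, e_row)) for e_row in shared_embedding]
        for h_row in hidden_states
    ]
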

@flax.struct.dataclass
class SampleState:
    cur_len: jnp.ndarray
    sequences: jnp.ndarray
    running_token: jnp.ndarray
    is_sent_finished: jnp.ndarray
    prng_key: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]
    model_kwargs_uncond: Dict[str, jnp.ndarray]

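
# Illustrative sketch (hypothetical helper, not part of the original code) of how
# DalleBart.prepare_inputs_for_generation builds its static decoder mask and
# position ids: a ones mask of length max_length - 1 is overwritten at (0, 0)
# with decoder_attention_mask (what lax.dynamic_update_slice does there), and the
# position ids are cumsum(mask, axis=-1) - 1. Assumes each mask row is no longer
# than max_length - 1; plain Python lists stand in for jnp arrays.
def _sketch_generation_mask_and_positions(decoder_attention_mask, max_length):
    extended, positions = [], []
    for row in decoder_attention_mask:
        ext_row = [1] * (max_length - 1)
        ext_row[: len(row)] = list(row)  # same effect as dynamic_update_slice at (0, 0)
        extended.append(ext_row)
        running, pos_row = 0, []
        for m in row:
            running += m
            pos_row.append(running - 1)  # cumsum(axis=-1) - 1
        positions.append(pos_row)
    return extended, positions
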

@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using sampling.
    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None

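
# Illustrative sketch (hypothetical helper, not part of the original code) of the
# key renaming done by DalleBart.unscan on a flattened params dict: every entry
# whose key contains "layers", e.g. (..., "layers", "FlaxBartEncoderLayers", ...),
# is assumed to hold a leading axis of size n_layers and is split into per-layer
# keys (..., "layers", "FlaxBartEncoderLayer_0", ...) by dropping the trailing
# "s" from the scanned module name. Plain lists stand in for stacked jnp arrays.
def _sketch_unscan_flat(flat_params):
    flat_params = dict(flat_params)  # do not mutate the caller's dict
    scanned_keys = [k for k in flat_params if "layers" in k]
    for k in scanned_keys:
        stacked = flat_params.pop(k)
        name_idx = k.index("layers") + 1
        for i, layer_value in enumerate(stacked):
            new_k = (*k[:name_idx], f"{k[name_idx][:-1]}_{i}", *k[name_idx + 1 :])
            flat_params[new_k] = layer_value
    return flat_params
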

class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
    """
    Edits:
    - renamed from FlaxBartForConditionalGeneration
    - uses custom FlaxBartForConditionalGenerationModule
    - no bias in decode method
    - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
      related to position embedding during model.generate()
    - custom generate method to allow super conditions
    - num_params property
    - unscan function
    """

    module_class = FlaxBartForConditionalGenerationModule
    config_class = DalleBartConfig

    def num_params(self, params=None):
        if params is None:
            params = self.params
        num_params = jax.tree_util.tree_map(
            lambda param: param.size, flatten_dict(unfreeze(params))
        ).values()
        return sum(list(num_params))

    def unscan(self, params):
        if self.config.use_scan:
            self.config.use_scan = False
            params = flatten_dict(params)
            scanned_keys = [k for k in params.keys() if "layers" in k]
            for k in scanned_keys:
                v = params[k]
                name_idx = k.index("layers") + 1
                for i in range(len(v)):
                    new_k = (
                        *k[:name_idx],
                        f"{k[name_idx][:-1]}_{i}",
                        *k[name_idx + 1 :],
                    )
                    params[new_k] = v[i]
                del params[k]
            params = unflatten_dict(params)
        return params

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized, so a private flag
        # init_cache has to be passed down to ensure the cache is used. The cache must also be
        # marked as mutable so that it can be updated by the FlaxBartAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                shared_embedding = module.model.variables["params"]["shared"][
                    "embedding"
                ]
                lm_logits = module.lm_head.apply(
                    {"params": {"kernel": shared_embedding.T}}, hidden_states
                )
            else:
                lm_logits = module.lm_head(hidden_states)

            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since the decoder uses a
        # causal mask, those positions are masked anyway. Thus we can create a single
        # static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def generate(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jnp.ndarray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        condition_scale: Optional[float] = 1.0,
        input_ids_uncond: Optional[jnp.ndarray] = None,
        attention_mask_uncond: Optional[jnp.ndarray] = None,
        **model_kwargs,
    ):
        """Edit: Allow super conditioning."""

        # set init values
        max_length = max_length if max_length is not None else self.config.max_length
        bos_token_id = (
            bos_token_id if bos_token_id is not None else self.config.bos_token_id
        )
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.eos_token_id
        )
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError(
                "`decoder_start_token_id` has to be defined for encoder-decoder generation."
            )

        do_sample = do_sample if do_sample is not None else self.config.do_sample
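        num_beams = num_beams if num_beams is not None else self.config.num_beams
        # only populated when super conditioning is used (condition_scale != 1.0)
        model_kwargs_uncond = None
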
        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                model_kwargs_input = dict(model_kwargs)
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                    input_ids,
                    params,
                    {"attention_mask": attention_mask, **model_kwargs_input},
                )
                if condition_scale != 1.0:
                    assert (
                        input_ids_uncond is not None
                    ), "`input_ids_uncond` has to be defined for super conditioning."
                    assert (
                        do_sample is True
                    ), "`do_sample` has to be True for super conditioning."
                    assert (
                        num_beams == 1
                    ), "`num_beams` has to be 1 for super conditioning."
                    model_kwargs_uncond = (
                        self._prepare_encoder_decoder_kwargs_for_generation(
                            input_ids_uncond,
                            params,
                            {
                                "attention_mask": attention_mask_uncond,
                                **model_kwargs_input,
                            },
                        )
                    )
                else:
                    model_kwargs_uncond = None
            # prepare decoder_input_ids for generation
            input_ids = (
                jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
            )

        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, temperature=temperature
            )
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
                condition_scale=condition_scale,
                model_kwargs_uncond=model_kwargs_uncond,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"][
                    "last_hidden_state"
                ] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"],
                    num_beams=num_beams,
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )

            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")

    def _sample(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor=None,
        logits_warper=None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
        condition_scale: float = 1.0,
        model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = (
            pad_token_id if pad_token_id is not None else self.config.pad_token_id
        )
        eos_token_id = (
            eos_token_id if eos_token_id is not None else self.config.eos_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            input_ids, max_length, **model_kwargs
        )
        if condition_scale != 1.0:
            model_kwargs_uncond = self.prepare_inputs_for_generation(
                input_ids, max_length, **model_kwargs_uncond
            )

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
            model_kwargs_uncond=model_kwargs_uncond,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(
                has_reached_max_length, all_sequence_finished
            )
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(
                state.running_token, params=params, **state.model_kwargs
            )

            logits = model_outputs.logits[:, -1]

            # perform super conditioning
            # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
            if condition_scale != 1.0:
                model_outputs_uncond = model(
                    state.running_token, params=params, **state.model_kwargs_uncond
                )
                logits_uncond = model_outputs_uncond.logits[:, -1]
                logits = logits_uncond + condition_scale * (logits - logits_uncond)
            else:
                model_outputs_uncond = None

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_k, top_p, temperature
            logits = logits_warper(state.sequences, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (
                next_token == eos_token_id
            )
            next_token = (
                next_token * ~next_is_sent_finished
                + pad_token_id * next_is_sent_finished
            )
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(
                state.sequences, next_token, (0, state.cur_len)
            )
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs
            )
            next_model_kwargs_uncond = (
                self.update_inputs_for_generation(
                    model_outputs_uncond, state.model_kwargs_uncond
                )
                if condition_scale != 1.0
                else None
            )

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                model_kwargs_uncond=next_model_kwargs_uncond,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(
                sample_search_cond_fn, sample_search_body_fn, state
            )
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
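
The body of `sample_search_body_fn` does two pieces of per-step arithmetic: super conditioning mixes the logits as `logits_uncond + condition_scale * (logits - logits_uncond)`, and finished sequences (including the step where EOS is drawn) get `pad_token_id` written instead of the sampled token. The sketch below reproduces that arithmetic on plain Python lists for one step; the token ids and logits are hypothetical, and `argmax` is a greedy stand-in for `jax.random.categorical` so the output is deterministic. The real code runs the same operations on batched `jnp` arrays inside `lax.while_loop`.

```python
from typing import List, Tuple


def super_condition(
    logits_cond: List[float], logits_uncond: List[float], condition_scale: float
) -> List[float]:
    """Mix logits as `_sample` does: uncond + scale * (cond - uncond)."""
    return [u + condition_scale * (c - u) for c, u in zip(logits_cond, logits_uncond)]


def mask_finished(
    next_tokens: List[int],
    is_finished: List[bool],
    eos_token_id: int,
    pad_token_id: int,
) -> Tuple[List[int], List[bool]]:
    """Flag sequences that emit EOS and write pad for every finished sequence,
    including the step where EOS itself was drawn (as in `_sample`)."""
    next_finished = [f or t == eos_token_id for t, f in zip(next_tokens, is_finished)]
    tokens = [pad_token_id if f else t for t, f in zip(next_tokens, next_finished)]
    return tokens, next_finished


def argmax(xs: List[float]) -> int:
    # Greedy stand-in for `jax.random.categorical`, to keep the example deterministic.
    return max(range(len(xs)), key=lambda i: xs[i])


# Toy 4-token vocabulary; token 3 plays EOS and token 0 plays pad (hypothetical ids).
cond = [1.0, 2.0, 0.5, 0.0]
uncond = [1.0, 1.5, 1.0, 0.0]

print(super_condition(cond, uncond, 1.0))  # [1.0, 2.0, 0.5, 0.0]
guided = super_condition(cond, uncond, 3.0)
print(guided)  # [1.0, 3.0, -0.5, 0.0]

# Batch of two sequences: the first takes the guided argmax, the second emits EOS.
print(mask_finished([argmax(guided), 3], [False, False], eos_token_id=3, pad_token_id=0))
# ([1, 0], [False, True])
```

With `condition_scale = 1.0` the mix returns the conditional logits unchanged, which is why `generate` only builds the unconditional encoder outputs when the scale differs from 1.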


================================================
FILE: src/dalle_mini/model/partitions.py
================================================
import re

from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P

# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
# Sentinels
_unmatched = object()

# For specifying empty leaf dict `{}`
empty_dict = object()


def _match(qs, ks):
    """Return True if regexes in qs match any window of strings in tuple ks."""
    # compile regexes and force complete match
    qts = tuple(map(lambda x: re.compile(x + "$"), qs))
    for i in range(len(ks) - len(qs) + 1):
        matches = [x.match(y) for x, y in zip(qts, ks[i:])]
        if matches and all(matches):
            return True
    return False


def _replacement_rules(rules):
    def replace(key, val):
        for rule, replacement in rules:
            if _match(rule, key):
                return replacement
        return val

    return replace


def _get_partition_rules():
    return [
        # embeddings
        (("embed_positions", "embedding"), P("mp", None)),
        (("embed_tokens", "embedding"), P("mp", None)),
        (("rel_bias", "embedding"), P(None, "mp")),
        # attention
        (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
        (("out_proj", "kernel"), P("mp", None)),
        # FFN
        (("Dense_0", "kernel"), P(None, "mp")),
        (("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
        (("GLU.*", "Dense_2", "kernel"), P("mp", None)),
        (("FFN.*", "Dense_1", "kernel"), P("mp", None)),
        # layer norms
        (("(bias|scale)",), None),
        (("lm_head", "kernel"), P(None, "mp")),
        # head scale and tau
        (("(head_scale|tau)",), None),
    ]


def set_partitions(in_dict, use_scan):
    rules = _get_partition_rules()
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    for k, v in result.items():
        if v is _unmatched:
            print(f"Unmatched -> {k}")
    if use_scan:
        # add None dimension to layers
        result = {
            k: (P(*(None,) + v) if v is not None else None)
            if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
            else v
            for k, v in result.items()
        }
    assert _unmatched not in result.values(), "Incomplete partition spec."
    return freeze(unflatten_dict(result))
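
`_match` compiles each element of a rule as an anchored regex and slides the rule over the flattened parameter path, so a rule fires when it fully matches any contiguous window of path components; `_replacement_rules` then returns the spec of the first rule that fires. The standalone sketch below reproduces that lookup with a subset of the rules. Plain strings stand in for `PartitionSpec` objects, the parameter paths are hypothetical, and raising `KeyError` on a miss is a simplification of the print-then-assert behaviour in `set_partitions`.

```python
import re
from typing import Optional, Sequence, Tuple

Rule = Tuple[Tuple[str, ...], Optional[str]]

# Subset of `_get_partition_rules`, with strings in place of `PartitionSpec`.
RULES: Sequence[Rule] = [
    (("(q_proj|k_proj|v_proj)", "kernel"), "P(None, 'mp')"),
    (("out_proj", "kernel"), "P('mp', None)"),
    (("(bias|scale)",), None),
]


def match_window(patterns: Sequence[str], path: Tuple[str, ...]) -> bool:
    """True if every pattern fully matches its element of some window of `path`."""
    compiled = [re.compile(p + "$") for p in patterns]
    for start in range(len(path) - len(patterns) + 1):
        window = path[start : start + len(patterns)]
        if all(c.match(part) for c, part in zip(compiled, window)):
            return True
    return False


def lookup(path: Tuple[str, ...]) -> Optional[str]:
    """Return the spec of the first matching rule; raise if no rule matches."""
    for patterns, spec in RULES:
        if match_window(patterns, path):
            return spec
    raise KeyError(f"Unmatched -> {path}")


print(lookup(("model", "decoder", "layers", "0", "self_attn", "q_proj", "kernel")))
# P(None, 'mp')
print(lookup(("model", "encoder", "layernorm_embedding", "scale")))
# None
```

Rule order matters because the first match wins. With `use_scan=True`, `set_partitions` additionally prepends a `None` axis to the specs of parameters under `FlaxBartEncoderLayers` / `FlaxBartDecoderLayers`, presumably because scanned layers are stacked along an extra leading layer dimension that is not sharded.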


================================================
FILE: src/dalle_mini/model/processor.py
================================================
""" DalleBart processor """

from typing import List

import jax.numpy as jnp

from .configuration import DalleBartConfig
from .text import TextNormalizer
from .tokenizer import DalleBartTokenizer
from .utils import PretrainedFromWandbMixin


class DalleBartProcessorBase:
    def __init__(
        self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int
    ):
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        if normalize_text:
            self.text_processor = TextNormalizer()
        # create unconditional tokens
        uncond = self.tokenizer(
            "",
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data
        self.input_ids_uncond = uncond["input_ids"]
        self.attention_mask_uncond = uncond["attention_mask"]

    def __call__(self, text: List[str] = None):
        # check that text is not a string
        assert not isinstance(text, str), "text must be a list of strings"

        if self.normalize_text:
            text = [self.text_processor(t) for t in text]
        res = self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data
        # tokens used only with super conditioning
        n = len(text)
        res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
        res["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, n, axis=0)
        return res

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)


class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    pass
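
`DalleBartProcessorBase` tokenizes the empty string once, at construction time, into arrays of shape `(1, max_text_length)`. On every call it repeats that single row along the batch axis so the unconditional inputs line up with the `n` prompts. The shape bookkeeping can be checked with NumPy's `np.repeat`, which behaves like `jnp.repeat` here; the token ids below are made up for illustration (BART-style `<s>`, `</s>`, then padding).

```python
import numpy as np

max_text_length = 6
# Hypothetical encoding of "" padded to max_text_length: <s> </s> <pad> <pad> ...
input_ids_uncond = np.array([[0, 2, 1, 1, 1, 1]])
attention_mask_uncond = np.array([[1, 1, 0, 0, 0, 0]])

prompts = ["a red apple", "a blue car", "sunset over the sea"]
n = len(prompts)

# Repeat the single unconditional row once per prompt along axis 0 (the batch axis).
batch_ids_uncond = np.repeat(input_ids_uncond, n, axis=0)
batch_mask_uncond = np.repeat(attention_mask_uncond, n, axis=0)

print(batch_ids_uncond.shape)  # (3, 6)
```

Every row is identical, and the resulting keys `input_ids_uncond` / `attention_mask_uncond` match the argument names that `generate` accepts for super conditioning.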


================================================
FILE: src/dalle_mini/model/text.py
================================================
"""
Utilities for processing text.
"""

import html
import math
import random
import re
from pathlib import Path

import emoji
import ftfy
from huggingface_hub import hf_hub_download
from unidecode import unidecode

# based on wiki word occurrence
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
temp_token = "xtokx"  # avoid repeating chars


class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        wiki_word_frequency = hf_hub_download(
            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
        )
        self._word_cost = (
            l.split()[0]
            for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        )
        self._word_cost = {
            str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
        }
        self._max_word = max(len(x) for x in self._word_cost.keys())
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        l = [self._split(x) for x in self._SPLIT_RE.split(s)]
        return " ".join([item for sublist in l for item in sublist])

    def _split(self, s):
        # Find the best match for the i first characters, assuming cost has
        # been built for the i-1 first characters.
        # Returns a pair (match_cost, match_length).
        def best_match(i):
            candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
            return min(
                (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
                for k, c in candidates
            )

        # Build the cost array
        cost = [0]
        for i in range(1, len(s) + 1):
            c, k = best_match(i)
            cost.append(c)

        # Backtrack to recover the minimal-cost string.
        out = []
        i = len(s)
        while i > 0:
            c, k = best_match(i)
            assert c == cost[i]
            newToken = True
            if not s[i - k : i] == "'":  # ignore a lone apostrophe
                if len(out) > 0:
                    # re-attach split 's and split digits
                    if out[-1] == "'s" or (
                        s[i - 1].isdigit() and out[-1][0].isdigit()
                    ):  # digit followed by digit
                        out[-1] = (
                            s[i - k : i] + out[-1]
                        )  # combine current token with previous token
                        newToken = False

            if newToken:
                out.append(s[i - k : i])

            i -= k

        return reversed(out)


def replace_person_token(t):
    "Used for CC12M"
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace(
            "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
        )
    return t


def fix_html(t):
    # from OpenAI CLIP
    return html.unescape(html.unescape(t))


def replace_punctuation_with_commas(t):
    return re.sub(r"[()[\].,|:;?!=+~\-\/{}]", ",", t)


def simplify_quotes(t):
    return re.sub("""['"`]""", ' " ', t)


def merge_quotes(t):
    return re.sub(r'(\s*"+\s*)+', ' " ', t)


def remove_comma_numbers(t):
    def _f(t):
        return re.sub(r"(\d),(\d{3})", r"\1\2", t)

    return _f(_f(t))


def pre_process_dot_numbers(t):
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)


def post_process_dot_numbers(t):
    return re.sub(f"{temp_token}dot{temp_token}", ".", t)


def pre_process_quotes(t):
    # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
    return re.sub(
        r"'(?=([stdm]|(ll)|(re)|(ve))\b)", rf"{temp_token}quote{temp_token}", t
    )


def post_process_quotes(t):
    return re.sub(f"{temp_token}quote{temp_token}", "'", t)


def pre_process_dates(t):
    return re.sub(r"(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)


def post_process_dates(t):
    return re.sub(f"{temp_token}slash{temp_token}", "/", t)


def merge_commas(t):
    return re.sub(r"(\s*,+\s*)+", ", ", t)


def add_space_after_commas(t):
    return re.sub(",", ", ", t)


def handle_special_chars(t):
    "Handle special characters"
    # replace "-" with a space when between words without space
    t = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always add space around some characters
    return re.sub(r"([%&\/$*])", r" \1 ", t)


def expand_hashtags(t, hashtag_processor):
    "Remove # and try to split words"
    return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)


_re_ignore_chars = r"[_#\\]"


def ignore_chars(t):
    "Ignore useless characters"
    return re.sub(_re_ignore_chars, " ", t)


def remove_extra_spaces(t):
    r"Remove extra spaces (including \t and \n)"
    return re.sub(r"\s+", " ", t)


def remove_repeating_chars(t):
    "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
    return re.sub(r"(\D)(\1{3,})", r"\1", t)


def remove_urls(t):
    return re.sub(r"http\S+", "", t)


def remove_html_tags(t):
    return re.sub("<[^<]+?>", " ", t)


def remove_first_last_commas(t):
    t = t.strip()
    t = t[:-1] if t and t[-1] == "," else t
    t = t[1:] if t and t[0] == "," else t
    return t.strip()


def remove_wiki_ref(t):
    t = re.sub(r"\A\s*\[\d+\]", "", t)
    return re.sub(r"\[\d+\]\s*\Z", "", t)


class TextNormalizer:
    "Normalize text"

    def __init__(self):
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t):
        # fix some characters
        t = ftfy.fix_text(t)
        # fix html
        t = fix_html(t)
        # decode emojis (would be removed by unidecode)
        t = emoji.demojize(t)
        # decode and simplify text: see unidecode library
        t = unidecode(t)
        # lower case
        t = t.lower()
        # replace <PERSON> (for CC12M)
        t = replace_person_token(t)
        # remove wiki reference (for WIT)
        t = remove_wiki_ref(t)
        # remove html tags
        t = remove_html_tags(t)
        # remove urls
        t = remove_urls(t)
        # remove commas in numbers
        t = remove_comma_numbers(t)
        # handle dots in numbers and quotes - Part 1
        t = pre_process_dot_numbers(t)
        t = pre_process_quotes(t)
        t = pre_process_dates(t)
        # handle special characters
        t = handle_special_chars(t)
        # handle hashtags
        t = expand_hashtags(t, self._hashtag_processor)
        # ignore useless characters
        t = ignore_chars(t)
        # simplify quotes
        t = simplify_quotes(t)
        # all punctuation becomes commas
        t = replace_punctuation_with_commas(t)
        # handle dots in numbers and quotes - Part 2
        t = post_process_dot_numbers(t)
        t = post_process_quotes(t)
        t = post_process_dates(t)
        # handle repeating characters
        t = remove_repeating_chars(t)
        # merge quotes
        t = merge_quotes(t)
        # merge commas
        t = merge_commas(t)
        # remove multiple spaces
        t = remove_extra_spaces(t)
        # remove first and last comma
        t = remove_first_last_commas(t)
        # always start with a space
        return f" {t}"
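
`HashtagProcessor._split` is the wordninja-style dynamic program: each known word costs `log(rank + 1)` from the Wikipedia frequency list, a forward pass fills `cost[i]` with the cheapest segmentation of the first `i` characters, and a backward pass peels off the best final word until the string is consumed. The minimal sketch below keeps that structure but uses a tiny hypothetical vocabulary in place of `enwiki-words-frequency.txt` and leaves out the apostrophe and digit re-attachment rules.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical vocabulary, most frequent word first; it stands in for the
# `enwiki-words-frequency.txt` list that `HashtagProcessor` downloads.
WORDS = ["the", "best", "sun", "set", "sunset", "ever", "be"]
WORD_COST: Dict[str, float] = {w: math.log(i + 1) for i, w in enumerate(WORDS)}
MAX_WORD = max(len(w) for w in WORDS)
UNKNOWN = float("inf")  # same role as 9e999 in `_split`


def split_words(s: str) -> List[str]:
    """Segment a string without spaces into the cheapest sequence of known words."""

    def best_match(i: int) -> Tuple[float, int]:
        # Cheapest (total cost, word length) for a word ending at position i,
        # given that cost[j] is already known for every j < i.
        candidates = enumerate(reversed(cost[max(0, i - MAX_WORD) : i]))
        return min(
            (c + WORD_COST.get(s[i - k - 1 : i].lower(), UNKNOWN), k + 1)
            for k, c in candidates
        )

    # Forward pass: cost[i] is the best cost for the first i characters.
    cost = [0.0]
    for i in range(1, len(s) + 1):
        cost.append(best_match(i)[0])

    # Backward pass: take the best final word, then repeat on the remaining prefix.
    out = []
    i = len(s)
    while i > 0:
        _, k = best_match(i)
        out.append(s[i - k : i])
        i -= k
    return list(reversed(out))


print(split_words("sunsetthebest"))  # ['sunset', 'the', 'best']
```

Here `sunset` wins over `sun` + `set` because `log(5)` is smaller than `log(3) + log(4)`: frequent long words are cheaper than several shorter ones. Lookups are lowercased but the returned pieces keep the original casing.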

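`TextNormalizer` protects dots inside tokens such as `3.5` (and, in the full pipeline, contractions and dates) by swapping them for `temp_token` placeholders, turns every remaining punctuation mark into a comma, and then restores the placeholders. The reduced round trip below reuses the same regexes for dots, punctuation and comma merging, plus the logic of `remove_first_last_commas`. It skips the other steps (quotes, dates, hashtags, emoji, lowercasing), and the function name `normalize_punctuation` is hypothetical.

```python
import re

temp_token = "xtokx"


def normalize_punctuation(t: str) -> str:
    # Part 1: protect dots between word characters (e.g. "3.5")
    t = re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
    # Every remaining punctuation mark becomes a comma
    t = re.sub(r"[()[\].,|:;?!=+~\-\/{}]", ",", t)
    # Part 2: restore the protected dots
    t = re.sub(f"{temp_token}dot{temp_token}", ".", t)
    # Collapse runs of commas and surrounding spaces into ", "
    t = re.sub(r"(\s*,+\s*)+", ", ", t)
    # Strip a leading/trailing comma, as `remove_first_last_commas` does
    t = t.strip()
    t = t[:-1] if t.endswith(",") else t
    t = t[1:] if t.startswith(",") else t
    return t.strip()


print(normalize_punctuation("rated 3.5 stars! great (really) value."))
# rated 3.5 stars, great, really, value
```

Without the placeholder step, the dot in `3.5` would also become a comma and the number would be split in two.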

================================================
FILE: src/dalle_mini/model/tokenizer.py
================================================
""" DalleBart tokenizer """
from transformers import BartTokenizerFast

from .utils import PretrainedFromWandbMixin


class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):
    pass


================================================
FILE: src/dalle_mini/model/utils.py
================================================
import os
import tempfile
from pathlib import Path

import wandb


class PretrainedFromWandbMixin:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            if ":" in pretrained_model_name_or_path and not os.path.isdir(
                pretrained_model_name_or_path
            ):
                # wandb artifact
                if wandb.run is not None:
                    artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
                else:
                    artifact = wandb.Api().artifact(pretrained_model_name_or_path)
                pretrained_model_name_or_path = artifact.download(tmp_dir)

            return super(PretrainedFromWandbMixin, cls).from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
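
`PretrainedFromWandbMixin.from_pretrained` treats a name as a W&B artifact reference when it contains a `:` (as in `entity/project/artifact:version`) and is not an existing local directory; the artifact is downloaded to a temporary directory whose path is then passed to the parent `from_pretrained`. Any other name goes to the parent class unchanged. The dispatch rule on its own, as a hypothetical helper `is_wandb_artifact` (the example names are illustrative only; nothing is downloaded):

```python
import os
import tempfile


def is_wandb_artifact(name_or_path: str) -> bool:
    """Mirror the check in `from_pretrained`: contains ':' and is not a local directory."""
    return ":" in name_or_path and not os.path.isdir(name_or_path)


print(is_wandb_artifact("dalle-mini/dalle-mini/mini-1:v0"))  # True
print(is_wandb_artifact("dalle-mini/dalle-mini"))  # False

# An existing local directory wins, even if its name contains ':'.
with tempfile.TemporaryDirectory() as tmp:
    local_dir = os.path.join(tmp, "ckpt:v0")
    os.makedirs(local_dir)
    print(is_wandb_artifact(local_dir))  # False
```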


================================================
FILE: tools/dataset/encode_dataset.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d0b72877",
   "metadata": {},
   "source": [
    "# Pre-encoding a dataset for DALLE·mini"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba7b31e6",
   "metadata": {},
   "source": [
    "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
    "\n",
    "Adapt it to your own dataset and image encoder.\n",
    "\n",
    "At the end you should have a dataset of pairs:\n",
    "* a caption defined as a string\n",
    "* an encoded image defined as a list of ints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b59489e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import torchvision.transforms as T\n",
    "\n",
    "import webdataset as wds\n",
    "\n",
    "import jax\n",
    "import braceexpand\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c4c1e6",
   "metadata": {},
   "source": [
    "## Configuration Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1265dbfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "shards = \"my_images/shard-{0000..0008}.tar\"  # defined using braceexpand format as used by webdataset\n",
    "encoded_output = Path(\"encoded_data\")  # where we will save our encoded data\n",
    "\n",
    "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
    "    \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
    "    \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
    ")\n",
    "\n",
    "# good defaults for a TPU v3-8\n",
    "batch_size = 128  # Per device\n",
    "num_workers = 8  # For parallel processing\n",
    "total_bs = batch_size * jax.device_count()  # You can use a smaller size while testing\n",
    "save_frequency = 128  # Number of batches written to each output file (180MB for f16 and 720MB for f8 per file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['XXX/shard-0000.tar',\n",
       " 'XXX/shard-0001.tar',\n",
       " 'XXX/shard-0002.tar',\n",
       " 'XXX/shard-0003.tar',\n",
       " 'XXX/shard-0004.tar',\n",
       " 'XXX/shard-0005.tar',\n",
       " 'XXX/shard-0006.tar',\n",
       " 'XXX/shard-0007.tar',\n",
       " 'XXX/shard-0008.tar']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shards = list(\n",
    "    braceexpand.braceexpand(shards)\n",
    ")  # better display for tqdm with known length"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75dba8e2",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1e8fb95",
   "metadata": {},
   "source": [
    "We load data using `webdataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ef5de9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = (\n",
    "    wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
    "    .decode(\"rgb\", handler=wds.warn_and_continue)\n",
    "    .to_tuple(\"jpg\", \"txt\")  # assumes image is in `jpg` and caption in `txt`\n",
    "    .batched(total_bs)  # load in batch per worker (faster)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90981824",
   "metadata": {},
   "source": [
    "Note:\n",
    "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
    "* you may need to resize images in your pipeline (with `map_dict`, for example); we assume they are already 256x256.\n",
    "* you can also filter out some items using `select`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "129c377d",
   "metadata": {},
   "source": [
    "We can now inspect our data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cac98cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "images, captions = next(iter(ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd268fbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5acfc4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "captions[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c24693c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "T.ToPILImage()(images[0].permute(2, 0, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3059ffb1",
   "metadata": {},
   "source": [
    "Finally we create our dataloader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c227c551",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = (\n",
    "    wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
    ")  # avoid partial batch at the end of each worker"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a354472b",
   "metadata": {},
   "source": [
    "## Image encoder\n",
    "\n",
    "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a8b818",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
    "from flax.jax_utils import replicate\n",
    "\n",
    "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
    "vqgan_params = replicate(vqgan.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ad01c3",
   "metadata": {},
   "source": [
    "## Encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20357f74",
   "metadata": {},
   "source": [
    "Encoding is straightforward: `shard` splits each batch across the local devices and `pmap` runs the encoder on all of them in parallel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322a4619",
   "metadata": {},
   "outputs": [],
   "source": [
    "from flax.training.common_utils import shard\n",
    "from functools import partial\n",
    "\n",
    "\n",
    "@partial(jax.pmap, axis_name=\"batch\")\n",
    "def p_encode(batch, params):\n",
    "    # params were replicated above: pmap maps over the leading (device) axis of every argument\n",
    "    _, indices = vqgan.encode(batch, params=params)\n",
    "    return indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6c10d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def encode_dataset(dataloader, output_dir, save_frequency):\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
    "    all_captions = []\n",
    "    all_encoding = []\n",
    "    n_file = 1\n",
    "    for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
    "        images = images.numpy()\n",
    "        n_devices = jax.local_device_count()\n",
    "        n = len(images) // n_devices * n_devices\n",
    "        if n != len(images):\n",
    "            # keep the largest number of images divisible by the device count\n",
    "            print(f\"Different sizes {n} vs {len(images)}\")\n",
    "            images = images[:n]\n",
    "            captions = captions[:n]\n",
    "        if not len(captions):\n",
    "            print(f\"No images/captions in batch...\")\n",
    "            continue\n",
    "        images = shard(images)\n",
    "        encoded = p_encode(images, vqgan_params)\n",
    "        encoded = encoded.reshape(-1, encoded.shape[-1])\n",
    "        all_captions.extend(captions)\n",
    "        all_encoding.extend(encoded.tolist())\n",
    "\n",
    "        # save files\n",
    "        if (idx + 1) % save_frequency == 0:\n",
    "            print(f\"Saving file {n_file}\")\n",
    "            batch_df = pd.DataFrame.from_dict(\n",
    "                {\"caption\": all_captions, \"encoding\": all_encoding}\n",
    "            )\n",
    "            batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
    "            all_captions = []\n",
    "            all_encoding = []\n",
    "            n_file += 1\n",
    "\n",
    "    if len(all_captions):\n",
    "        print(f\"Saving final file {n_file}\")\n",
    "        batch_df = pd.DataFrame.from_dict(\n",
    "            {\"caption\": all_captions, \"encoding\": all_encoding}\n",
    "        )\n",
    "        batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7704863d",
   "metadata": {},
   "outputs": [],
   "source": [
    "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8953dd84",
   "metadata": {},
   "source": [
    "----"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
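
Both `encode_dataset` and flax's `shard` assume that each batch splits evenly across devices: `shard` reshapes the leading axis from `(batch, ...)` to `(num_local_devices, batch // num_local_devices, ...)`, and the per-device outputs of `p_encode` are flattened back with a `reshape`. The NumPy sketch below reproduces that bookkeeping on fake data. It assumes 8 devices, and `trim_and_shard` / `unshard` are hypothetical helpers, not part of flax or dalle-mini.

```python
import numpy as np


def trim_and_shard(images, captions, num_devices):
    """Hypothetical helper: trim to a multiple of num_devices, then split the
    leading axis across devices the same way flax's `shard` reshapes it."""
    n = len(images) // num_devices * num_devices
    images, captions = images[:n], captions[:n]
    # (batch, H, W, C) -> (num_devices, batch // num_devices, H, W, C)
    sharded = images.reshape((num_devices, -1) + images.shape[1:])
    return sharded, captions


def unshard(per_device_outputs):
    """(num_devices, per_device_batch, seq) -> (batch, seq), like
    `encoded.reshape(-1, encoded.shape[-1])` in the notebook."""
    return per_device_outputs.reshape(-1, per_device_outputs.shape[-1])


# 20 fake 4x4 RGB images; with 8 devices only the first 16 can be used
images = np.arange(20 * 4 * 4 * 3, dtype=np.float32).reshape(20, 4, 4, 3)
captions = [f"caption {i}" for i in range(20)]
sharded, kept = trim_and_shard(images, captions, num_devices=8)
print(sharded.shape, len(kept))  # (8, 2, 4, 4, 3) 16

# Pretend each device returns 3 codes per image, all equal to the image's index
fake_codes = np.repeat(np.arange(16).reshape(8, 2, 1), 3, axis=-1)
flat = unshard(fake_codes)
print(flat[:5, 0])  # [0 1 2 3 4]
```

Because both reshapes preserve the original order, captions and encodings stay aligned when `all_captions` and `all_encoding` are extended side by side.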


================================================
FILE: tools/inference/inference_pipeline.ipynb
================================================
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "118UKH5bWCGa"
      },
      "source": [
        "# DALL·E mini - Inference pipeline\n",
        "\n",
        "*Generate images from a text prompt*\n",
        "\n",
        "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
        "\n",
        "This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
        "\n",
        "Just want to play? Use [the app](https://www.craiyon.com/) directly.\n",
        "\n",
        "To learn more about the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dS8LbaonYm3a"
      },
      "source": [
        "## 🛠️ Installation and set-up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uzjAM2GBYpZX"
      },
      "outputs": [],
      "source": [
        "# Required only for colab environments + GPU\n",
        "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
        "\n",
        "# Install required libraries\n",
        "!pip install -q dalle-mini\n",
        "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ozHzTkyv8cqU"
      },
      "source": [
        "We load the required models:\n",
        "* DALL·E mini to generate encoded images from text\n",
        "* VQGAN for decoding images\n",
        "* CLIP for scoring predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K6CxW2o42f-w"
      },
      "outputs": [],
      "source": [
        "# Model references\n",
        "\n",
        "# dalle-mega\n",
        "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
        "DALLE_COMMIT_ID = None\n",
        "\n",
        "# if the notebook crashes too often, you can use dalle-mini instead by uncommenting the line below\n",
        "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
        "\n",
        "# VQGAN model\n",
        "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
        "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yv-aR3t4Oe5v"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# check how many devices are available\n",
        "jax.local_device_count()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "92zYmvsQ38vL"
      },
      "outputs": [],
      "source": [
        "# Load models & tokenizer\n",
        "from dalle_mini import DalleBart, DalleBartProcessor\n",
        "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
        "from transformers import CLIPProcessor, FlaxCLIPModel\n",
        "\n",
        "# Load dalle-mini\n",
        "model, params = DalleBart.from_pretrained(\n",
        "    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
        ")\n",
        "\n",
        "# Load VQGAN\n",
        "vqgan, vqgan_params = VQModel.from_pretrained(\n",
        "    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o_vH2X1tDtzA"
      },
      "source": [
        "Model parameters are replicated on each device for faster inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wtvLoM48EeVw"
      },
      "outputs": [],
      "source": [
        "from flax.jax_utils import replicate\n",
        "\n",
        "params = replicate(params)\n",
        "vqgan_params = replicate(vqgan_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0A9AHQIgZ_qw"
      },
      "source": [
        "Model functions are compiled and parallelized to take advantage of multiple devices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sOtoOmYsSYPz"
      },
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "\n",
        "\n",
        "# model inference\n",
        "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
        "def p_generate(\n",
        "    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
        "):\n",
        "    return model.generate(\n",
        "        **tokenized_prompt,\n",
        "        prng_key=key,\n",
        "        params=params,\n",
        "        top_k=top_k,\n",
        "        top_p=top_p,\n",
        "        temperature=temperature,\n",
        "        condition_scale=condition_scale,\n",
        "    )\n",
        "\n",
        "\n",
        "# decode image\n",
        "@partial(jax.pmap, axis_name=\"batch\")\n",
        "def p_decode(indices, params):\n",
        "    return vqgan.decode_code(indices, params=params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HmVN6IBwapBA"
      },
      "source": [
        "Random keys are passed to the model, one per device, so that each device generates different images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4CTXmlUkThhX"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "# create a random key\n",
        "seed = random.randint(0, 2**32 - 1)\n",
        "key = jax.random.PRNGKey(seed)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BrnVyCo81pij"
      },
      "source": [
        "## 🖍 Text Prompt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rsmj0Aj5OQox"
      },
      "source": [
        "Our model requires prompts to be processed (tokenized) with `DalleBartProcessor` before generation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YjjhUychOVxm"
      },
      "outputs": [],
      "source": [
        "from dalle_mini import DalleBartProcessor\n",
        "\n",
        "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BQ7fymSPyvF_"
      },
      "source": [
        "Let's define some text prompts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_0vI9ge1oKr"
      },
      "outputs": [],
      "source": [
        "prompts = [\n",
        "    \"sunset over a lake in the mountains\",\n",
        "    \"the Eiffel tower landing on the moon\",\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XlZUG3SCLnGE"
      },
      "source": [
        "Note: for faster inference, we could repeat the same prompt several times in the list so that each generation call produces more images of it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VKjEZGjtO49k"
      },
      "outputs": [],
      "source": [
        "tokenized_prompts = processor(prompts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-CEJBnuJOe5z"
      },
      "source": [
        "Finally, we replicate the prompts onto each device."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lQePgju5Oe5z"
      },
      "outputs": [],
      "source": [
        "tokenized_prompt = replicate(tokenized_prompts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "phQ9bhjRkgAZ"
      },
      "source": [
        "## 🎨 Generate images\n",
        "\n",
        "We generate images using the dalle-mini model and decode them with the VQGAN."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d0wVkXpKqnHA"
      },
      "outputs": [],
      "source": [
        "# number of predictions per prompt\n",
        "n_predictions = 8\n",
        "\n",
        "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
        "gen_top_k = None\n",
        "gen_top_p = None\n",
        "temperature = None\n",
        "cond_scale = 10.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SDjEx9JxR3v8"
      },
      "outputs": [],
      "source": [
        "from flax.training.common_utils import shard_prng_key\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from tqdm.notebook import trange\n",
        "\n",
        "print(f\"Prompts: {prompts}\\n\")\n",
        "# generate images\n",
        "images = []\n",
        "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
        "    # get a new key\n",
        "    key, subkey = jax.random.split(key)\n",
        "    # generate images\n",
        "    encoded_images = p_generate(\n",
        "        tokenized_prompt,\n",
        "        shard_prng_key(subkey),\n",
        "        params,\n",
        "        gen_top_k,\n",
        "        gen_top_p,\n",
        "        temperature,\n",
        "        cond_scale,\n",
        "    )\n",
        "    # remove BOS\n",
        "    encoded_images = encoded_images.sequences[..., 1:]\n",
        "    # decode images\n",
        "    decoded_images = p_decode(encoded_images, vqgan_params)\n",
        "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
        "    for decoded_img in decoded_images:\n",
        "        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
        "        images.append(img)\n",
        "        display(img)\n",
        "        print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tw02wG9zGmyB"
      },
      "source": [
        "## 🏅 Optional: Rank images by CLIP score\n",
        "\n",
        "We can rank images according to their CLIP score.\n",
        "\n",
        "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RGjlIW_f6GA0"
      },
      "outputs": [],
      "source": [
        "# CLIP model\n",
        "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
        "CLIP_COMMIT_ID = None\n",
        "\n",
        "# Load CLIP\n",
        "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
        "    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
        ")\n",
        "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
        "clip_params = replicate(clip_params)\n",
        "\n",
        "\n",
        "# score images\n",
        "@partial(jax.pmap, axis_name=\"batch\")\n",
        "def p_clip(inputs, params):\n",
        "    logits = clip(params=params, **inputs).logits_per_image\n",
        "    return logits"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FoLXpjCmGpju"
      },
      "outputs": [],
      "source": [
        "from flax.training.common_utils import shard\n",
        "\n",
        "# get clip scores\n",
        "clip_inputs = clip_processor(\n",
        "    text=prompts * jax.device_count(),\n",
        "    images=images,\n",
        "    return_tensors=\"np\",\n",
        "    padding=\"max_length\",\n",
        "    max_length=77,\n",
        "    truncation=True,\n",
        ").data\n",
        "logits = p_clip(shard(clip_inputs), clip_params)\n",
        "\n",
        "# organize scores per prompt: one row per prompt, one column per prediction\n",
        "p = len(prompts)\n",
        "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).reshape(p, -1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AAWRm70LgED"
      },
      "source": [
        "Let's now display images ranked by CLIP score."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zsgxxubLLkIu"
      },
      "outputs": [],
      "source": [
        "for i, prompt in enumerate(prompts):\n",
        "    print(f\"Prompt: {prompt}\\n\")\n",
        "    for idx in logits[i].argsort()[::-1]:\n",
        "        display(images[idx * p + i])\n",
        "        print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
        "    print()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oZT9i3jCjir0"
      },
      "source": [
        "## 🪄 Optional: Save your Generated Images as W&B Tables\n",
        "\n",
        "W&B Tables is an interactive 2D grid with support for rich media logging. Use it to save the generated images to your W&B dashboard and share them with the world."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-pSiv6Vwjkn0"
      },
      "outputs": [],
      "source": [
        "import wandb\n",
        "\n",
        "# Initialize a W&B run.\n",
        "project = \"dalle-mini-tables-colab\"\n",
        "run = wandb.init(project=project)\n",
        "\n",
        "# Initialize an empty W&B Table.\n",
        "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
        "gen_table = wandb.Table(columns=columns)\n",
        "\n",
        "# Add data to the table.\n",
        "for i, prompt in enumerate(prompts):\n",
        "    # If CLIP scores exist (the optional CLIP cells were run), sort the images by score\n",
        "    if globals().get(\"logits\") is not None:\n",
        "        idxs = logits[i].argsort()[::-1]\n",
        "        tmp_imgs = images[i :: len(prompts)]\n",
        "        tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
        "    else:\n",
        "        tmp_imgs = images[i :: len(prompts)]\n",
        "\n",
        "    # Add the data to the table.\n",
        "    gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
        "\n",
        "# Log the Table to W&B dashboard.\n",
        "wandb.log({\"Generated Images\": gen_table})\n",
        "\n",
        "# Close the W&B run.\n",
        "run.finish()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ck2ZnHwVjnRd"
      },
      "source": [
        "Click on the link above to check out your generated images."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "machine_shape": "hm",
      "name": "DALL·E mini - Inference pipeline.ipynb",
      "provenance": [],
      "gpuType": "A100",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
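
The CLIP ranking relies on the order in which images are produced. Each `p_generate` call yields one image per prompt on every device, and the flattened output is device-major, so `images[j]` belongs to prompt `j % p`. After `shard`, each device scores its slice of images against its own copy of the `p` prompts, so `logits` has shape `(num_devices, images_per_device, p)`. The NumPy sketch below fills fake logits with traceable values (`100 * image_index + prompt_index`) to check that `logits[:, i::p, i]`, flattened to one row per prompt, lines up with `images[m * p + i]`. The device, prompt, and prediction counts are arbitrary assumptions.

```python
import numpy as np

num_devices, p, n_predictions = 2, 3, 8  # assumed sizes for illustration
n_images = n_predictions * p             # length of the `images` list after generation
# image j was generated for prompt j % p (device-major, prompt-minor ordering)
prompt_of = np.arange(n_images) % p

# Fake CLIP logits after sharding: (num_devices, images_per_device, p)
images_per_device = n_images // num_devices
image_index = np.arange(n_images).reshape(num_devices, images_per_device)
logits = 100 * image_index[..., None] + np.arange(p)

# For each prompt i, keep the scores of its own images against its own text
per_prompt = np.asarray([logits[:, i::p, i] for i in range(p)]).reshape(p, -1)

for i in range(p):
    for m in range(per_prompt.shape[1]):
        j = m * p + i  # index into `images`, as used when displaying ranked images
        assert per_prompt[i, m] == 100 * j + i
        assert prompt_of[j] == i
print(per_prompt.shape)  # (3, 8)
```

Flattening with `reshape(p, -1)` keeps a `(prompts, predictions)` array for any number of devices or prompts, whereas `squeeze()` leaves a 3-D array when several devices each hold more than one image per prompt, and drops the prompt axis when there is a single prompt.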

================================================
FILE: tools/inference/run_infer_notebook.sh
================================================
#!/bin/bash
jupyter notebook --ip 0.0.0.0 --no-browser --allow-root

================================================
FILE: tools/train/config/mega/config.json
================================================
{
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.0,
  "bos_token_id": 16385,
  "d_model": 2048,
  "decoder_attention_heads": 32,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 24,
  "decoder_start_token_id": 16384,
  "do_sample": true,
  "dropout": 0.0,
  "encoder_attention_heads": 32,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 24,
  "encoder_vocab_size": 50272,
  "eos_token_id": 16385,
  "force_ln_scale": false,
  "gradient_checkpointing": false,
  "image_length": 256,
  "image_vocab_size": 16415,
  "init_std": 0.01,
  "is_encoder_decoder": true,
  "ln_positions": "normformer",
  "ln_type": "layernorm",
  "max_length": 257,
  "max_text_length": 64,
  "min_length": 257,
  "model_type": "dallebart",
  "normalize_text": true,
  "pad_token_id": 16385,
  "scale_embedding": false,
  "sinkhorn_iters": 1,
  "tau_init": 0.05,
  "tie_word_embeddings": false,
  "use_absolute_position_embeddings": true,
  "use_alibi": false,
  "use_bias": false,
  "use_cache": true,
  "use_cosine_attention": false,
  "use_deepnet_scaling": false,
  "use_final_ln_decoder": true,
  "use_final_ln_encoder": true,
  "use_glu": true,
  "use_head_scale": false,
  "use_swin_position_embeddings": false
}
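
Several fields in the mega config depend on each other. The stdlib sketch below runs a few consistency checks on a subset of the values above:

- Attention heads must divide `d_model`.
- A generated sequence is one decoder start token plus `image_length` image tokens. The inference notebook strips that first token with `sequences[..., 1:]`.
- The special token ids sit just after the image codes.

`check_config` is a hypothetical helper. The 16384-entry codebook and 16x downsampling are assumptions read from the VQGAN repository name (`vqgan_imagenet_f16_16384`). The 256x256 output size is an assumption taken from the decoding `reshape`.

```python
import json

# Subset of the values in tools/train/config/mega/config.json
mega = json.loads("""{
  "d_model": 2048, "encoder_attention_heads": 32, "decoder_attention_heads": 32,
  "image_length": 256, "max_length": 257, "min_length": 257,
  "decoder_start_token_id": 16384, "bos_token_id": 16385,
  "eos_token_id": 16385, "pad_token_id": 16385, "image_vocab_size": 16415
}""")

VQGAN_CODEBOOK_SIZE = 16384  # assumption: from the "f16_16384" VQGAN name
DOWNSAMPLE = 16              # assumption: "f16" downsampling factor
IMAGE_SIZE = 256             # assumption: decoded images are 256x256


def check_config(cfg, codebook_size):
    """Hypothetical sanity checks; not part of dalle_mini. Returns the head size."""
    # attention heads must split d_model evenly
    head_dim = cfg["d_model"] // cfg["decoder_attention_heads"]
    assert head_dim * cfg["decoder_attention_heads"] == cfg["d_model"]
    # one token per 16x16 patch of a 256x256 image
    assert (IMAGE_SIZE // DOWNSAMPLE) ** 2 == cfg["image_length"]
    # generated sequence = 1 decoder start token + image_length image tokens
    assert cfg["max_length"] == cfg["min_length"] == cfg["image_length"] + 1
    # special tokens come right after the VQGAN codes, inside the image vocabulary
    special = {cfg["decoder_start_token_id"], cfg["bos_token_id"],
               cfg["eos_token_id"], cfg["pad_token_id"]}
    assert min(special) >= codebook_size
    assert max(special) < cfg["image_vocab_size"]
    return head_dim


print(check_config(mega, VQGAN_CODEBOOK_SIZE))  # 64
```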


================================================
FILE: tools/train/config/micro/config.json
================================================
{
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.0,
  "bos_token_id": 16385,
  "d_model": 256,
  "decoder_attention_heads": 2,
  "decoder_ffn_dim": 256,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "decoder_start_token_id": 16384,
  "dropout": 0.0,
  "encoder_attention_heads": 2,
  "encoder_ffn_dim": 256,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "encoder_vocab_size": 50264,
  "eos_token_id": 16385,
  "image_length": 256,
  "image_vocab_size": 16391,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_text_length": 64,
  "model_type": "dallebart",
  "normalize_text": true,
  "pad_token_id": 16385,
  "scale_embedding": false,
  "tie_word_embeddings": false,
  "use_cache": true
}


================================================
FILE: tools/train/config/mini/config.json
================================================
{
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.0,
  "bos_token_id": 16385,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layers": 12,
  "decoder_start_token_id": 16384,
  "dropout": 0.0,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layers": 12,
  "encoder_vocab_size": 50264,
  "eos_token_id": 16385,
  "gradient_checkpointing": false,
  "image_length": 256,
  "image_vocab_size": 16391,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_text_length": 64,
  "model_type": "dallebart",
  "normalize_text": true,
  "pad_token_id": 16385,
  "scale_embedding": false,
  "tie_word_embeddings": false,
  "use_cache": true
}


================================================
FILE: tools/train/config/mini_glu/config.json
================================================
{
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "attention_dropout": 0.0,
  "bos_token_id": 16385,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 2730,
  "decoder_layers": 12,
  "decoder_start_token_id": 16384,
  "dropout": 0.0,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 2730,
  "encoder_layers": 12,
  "encoder_vocab_size": 50300,
  "eos_token_id": 16385,
  "gradient_checkpointing": false,
  "image_length": 256,
  "image_vocab_size": 16400,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "max_text_length": 64,
  "model_type": "dallebart",
  "normalize_text": true,
  "pad_token_id": 16385,
  "scale_embedding": false,
  "tie_word_embeddings": false,
  "use_scan": false,
  "use_cache": true
}
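
One difference between the mini and mini_glu configs is the feed-forward size. mini uses a `decoder_ffn_dim` of 4096, while mini_glu uses 2730. The name suggests a GLU feed-forward block. This is consistent with the `GLU_0` kernels in the parameter tree printed by the embeddings notebook: two of shape `(1024, 2730)` and one of shape `(2730, 1024)`. One plausible reading, not stated in the repository, is that 2730 ≈ 2/3 × 4096 keeps the feed-forward weight count roughly unchanged, since a GLU has three projection matrices instead of two. The arithmetic, ignoring biases and layer norms:

```python
d_model = 1024
ffn_dim = 4096  # "decoder_ffn_dim" in the mini config
glu_dim = 2730  # "decoder_ffn_dim" in the mini_glu config


def ffn_weights(d, f):
    # standard FFN: up-projection (d x f) + down-projection (f x d)
    return 2 * d * f


def glu_weights(d, f):
    # GLU: two up-projections (d x f each, one gated) + down-projection (f x d)
    return 3 * d * f


print(ffn_weights(d_model, ffn_dim))  # 8388608
print(glu_weights(d_model, glu_dim))  # 8386560
print(int(ffn_dim * 2 / 3))           # 2730
```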


================================================
FILE: tools/train/embeddings_retrain_preparation.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "118UKH5bWCGa"
   },
   "source": [
    "# DALL·E mini - Embedding Retrain Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start with the dalle-mini model for faster experimentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "K6CxW2o42f-w"
   },
   "outputs": [],
   "source": [
    "DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
    "DALLE_COMMIT_ID = None\n",
    "\n",
    "# # dalle-mega\n",
    "# DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
    "# DALLE_COMMIT_ID = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "Yv-aR3t4Oe5v"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# check how many devices are available\n",
    "jax.local_device_count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the model twice to keep a copy of the original parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "92zYmvsQ38vL"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n",
      "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @  0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n",
      "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @  0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n"
     ]
    }
   ],
   "source": [
    "# Load model\n",
    "from dalle_mini import DalleBart, DalleBartProcessor\n",
    "\n",
    "model, params = DalleBart.from_pretrained(\n",
    "    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
    ")\n",
    "\n",
    "_, params_original = DalleBart.from_pretrained(\n",
    "    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model surgery: remove layers to be retrained"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the params tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "437833712"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(x.size for x in jax.tree_util.tree_leaves(params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"lm_head\": {\n",
      "    \"kernel\": [\n",
      "      1024,\n",
      "      16385\n",
      "    ]\n",
      "  },\n",
      "  \"model\": {\n",
      "    \"decoder\": {\n",
      "      \"embed_positions\": {\n",
      "        \"embedding\": [\n",
      "          256,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"embed_tokens\": {\n",
      "        \"embedding\": [\n",
      "          16385,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"final_ln\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layernorm_embedding\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ],\n",
      "        \"scale\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layers\": {\n",
      "        \"FlaxBartDecoderLayers\": {\n",
      "          \"FlaxBartAttention_0\": {\n",
      "            \"k_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"out_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"q_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"v_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"FlaxBartAttention_1\": {\n",
      "            \"k_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"out_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"q_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"v_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"GLU_0\": {\n",
      "            \"Dense_0\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_1\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_2\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                2730,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_0\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_1\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                2730\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"LayerNorm_0\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_1\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ],\n",
      "            \"scale\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_2\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_3\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ],\n",
      "            \"scale\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    },\n",
      "    \"encoder\": {\n",
      "      \"embed_positions\": {\n",
      "        \"embedding\": [\n",
      "          64,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"embed_tokens\": {\n",
      "        \"embedding\": [\n",
      "          50264,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"final_ln\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layernorm_embedding\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ],\n",
      "        \"scale\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layers\": {\n",
      "        \"FlaxBartEncoderLayers\": {\n",
      "          \"FlaxBartAttention_0\": {\n",
      "            \"k_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"out_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"q_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"v_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"GLU_0\": {\n",
      "            \"Dense_0\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_1\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_2\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                2730,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_0\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_1\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                2730\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"LayerNorm_0\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_1\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ],\n",
      "            \"scale\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "tree = jax.tree_util.tree_map(lambda x: x.shape, params)\n",
    "print(json.dumps(tree, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We remove the following parameters and reinitialize them later:\n",
    "- `lm_head`\n",
    "- `model.decoder.embed_positions`\n",
    "- `model.decoder.embed_tokens`\n",
    "- `model.decoder.final_ln`\n",
    "- `model.decoder.layernorm_embedding`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "del params[\"lm_head\"]\n",
    "for layer in [\"embed_positions\", \"embed_tokens\", \"final_ln\", \"layernorm_embedding\"]:\n",
    "    del params[\"model\"][\"decoder\"][layer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model': {'decoder': {'layers': {'FlaxBartDecoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n",
       "        1024,\n",
       "        1024)},\n",
       "      'out_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'q_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'v_proj': {'kernel': (12, 1024, 1024)}},\n",
       "     'FlaxBartAttention_1': {'k_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'out_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'q_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'v_proj': {'kernel': (12, 1024, 1024)}},\n",
       "     'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_1': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_2': {'kernel': (12, 2730, 1024)},\n",
       "      'LayerNorm_0': {'bias': (12, 1024)},\n",
       "      'LayerNorm_1': {'bias': (12, 2730)}},\n",
       "     'LayerNorm_0': {'bias': (12, 1024)},\n",
       "     'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)},\n",
       "     'LayerNorm_2': {'bias': (12, 1024)},\n",
       "     'LayerNorm_3': {'bias': (12, 1024), 'scale': (12, 1024)}}}},\n",
       "  'encoder': {'embed_positions': {'embedding': (64, 1024)},\n",
       "   'embed_tokens': {'embedding': (50264, 1024)},\n",
       "   'final_ln': {'bias': (1024,)},\n",
       "   'layernorm_embedding': {'bias': (1024,), 'scale': (1024,)},\n",
       "   'layers': {'FlaxBartEncoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n",
       "        1024,\n",
       "        1024)},\n",
       "      'out_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'q_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'v_proj': {'kernel': (12, 1024, 1024)}},\n",
       "     'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_1': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_2': {'kernel': (12, 2730, 1024)},\n",
       "      'LayerNorm_0': {'bias': (12, 1024)},\n",
       "      'LayerNorm_1': {'bias': (12, 2730)}},\n",
       "     'LayerNorm_0': {'bias': (12, 1024)},\n",
       "     'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)}}}}}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_map(lambda x: x.shape, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "404012016"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(x.size for x in jax.tree_util.tree_leaves(params))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reinitialize layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We save a checkpoint and then reload it. Reloading does not automatically reinitialize the missing keys, but it sets `_missing_keys` appropriately so we can initialize them later. We could achieve the same by setting that attribute ourselves, but I'll refrain from doing so because it's a private implementation detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "trimmed_checkpoint = \"mini-trimmed\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tcmalloc: large alloc 1610424320 bytes == 0x5632d11c8000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n",
      "tcmalloc: large alloc 3231449088 bytes == 0x56333119a000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n"
     ]
    }
   ],
   "source": [
    "model.save_pretrained(trimmed_checkpoint, params=params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The checkpoint mini-trimmed is missing required keys: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}. Make sure to call model.init_weights to initialize the missing weights.\n",
      "Some weights of DalleBart were
SYMBOL INDEX (182 symbols across 16 files)

FILE: app/gradio/app.py
  function infer (line 12) | def infer(prompt):

FILE: app/gradio/backend.py
  class ServiceError (line 10) | class ServiceError(Exception):
    method __init__ (line 11) | def __init__(self, status_code):
  function get_images_from_backend (line 15) | def get_images_from_backend(prompt, backend_url):
  function get_model_version (line 27) | def get_model_version(url):

FILE: app/streamlit/backend.py
  class ServiceError (line 10) | class ServiceError(Exception):
    method __init__ (line 11) | def __init__(self, status_code):
  function get_images_from_backend (line 15) | def get_images_from_backend(prompt, backend_url):
  function get_model_version (line 27) | def get_model_version(url):

FILE: src/dalle_mini/data.py
  class Dataset (line 16) | class Dataset:
    method __post_init__ (line 45) | def __post_init__(self):
    method preprocess (line 129) | def preprocess(self, tokenizer, config):
    method dataloader (line 303) | def dataloader(self, split, batch_size, epoch=None):
    method length (line 370) | def length(self):
  function shift_tokens_right (line 388) | def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
  function blank_caption_function (line 398) | def blank_caption_function(example, text_column, blank_caption_prob, rng...
  function normalize_function (line 408) | def normalize_function(example, text_column, text_normalizer):
  function filter_function (line 413) | def filter_function(
  function preprocess_function (line 430) | def preprocess_function(
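
The `shift_tokens_right` entry in `data.py` builds decoder inputs from target token ids. The minimal NumPy sketch below assumes it follows the usual BART convention, where `decoder_start_token_id` is prepended and the last token is dropped. It is an illustration, not a copy of `data.py`, and the start id `99` is an arbitrary placeholder.

```python
import numpy as np


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
    """Shift a batch of token ids one position to the right (BART-style sketch)."""
    shifted = np.zeros_like(input_ids)
    # Every token moves one slot right; the final token of each row falls off the end.
    shifted[:, 1:] = input_ids[:, :-1]
    # The first decoder position always receives the start token.
    shifted[:, 0] = decoder_start_token_id
    return shifted


labels = np.array([[5, 6, 7, 8], [9, 10, 11, 12]])
decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id=99)
print(decoder_input_ids.tolist())  # [[99, 5, 6, 7], [99, 9, 10, 11]]
```

`np.zeros_like` keeps the integer dtype of the ids, so the result can be used directly as token indices.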

FILE: src/dalle_mini/model/configuration.py
  class DalleBartConfig (line 26) | class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
    method __init__ (line 34) | def __init__(

FILE: src/dalle_mini/model/modeling.py
  function smelu (line 56) | def smelu(beta: Any = 1.0):
  function deepnet_init (line 82) | def deepnet_init(init_std, gain=1):
  class RMSNorm (line 117) | class RMSNorm(nn.Module):
    method __call__ (line 131) | def __call__(self, x):
    method _compute_rms_sq (line 150) | def _compute_rms_sq(self, x, axes):
    method _normalize (line 155) | def _normalize(
  function norm (line 189) | def norm(type, *args, **kwargs):
  function dot_product_attention_weights (line 198) | def dot_product_attention_weights(
  class FlaxBartAttention (line 276) | class FlaxBartAttention(FlaxBartAttention):
    method setup (line 288) | def setup(self) -> None:
    method __call__ (line 361) | def __call__(
  class GLU (line 491) | class GLU(nn.Module):
    method __call__ (line 501) | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp....
  class FFN (line 559) | class FFN(nn.Module):
    method __call__ (line 569) | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp....
  class FlaxBartEncoderLayer (line 616) | class FlaxBartEncoderLayer(nn.Module):
    method __call__ (line 629) | def __call__(
  class FlaxBartDecoderLayer (line 720) | class FlaxBartDecoderLayer(nn.Module):
    method __call__ (line 733) | def __call__(
  class FlaxBartEncoderLayerCollection (line 875) | class FlaxBartEncoderLayerCollection(nn.Module):
    method __call__ (line 885) | def __call__(
  class FlaxBartDecoderLayerCollection (line 981) | class FlaxBartDecoderLayerCollection(nn.Module):
    method __call__ (line 991) | def __call__(
  class FlaxBartEncoder (line 1114) | class FlaxBartEncoder(nn.Module):
    method setup (line 1126) | def setup(self):
    method __call__ (line 1158) | def __call__(
  class FlaxBartDecoder (line 1204) | class FlaxBartDecoder(nn.Module):
    method setup (line 1216) | def setup(self):
    method __call__ (line 1249) | def __call__(
  class FlaxBartModule (line 1302) | class FlaxBartModule(FlaxBartModule):
    method setup (line 1309) | def setup(self):
  class FlaxBartForConditionalGenerationModule (line 1329) | class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGener...
    method setup (line 1337) | def setup(self):
    method __call__ (line 1347) | def __call__(
  class SampleState (line 1399) | class SampleState:
  class FlaxSampleOutput (line 1410) | class FlaxSampleOutput(ModelOutput):
  class DalleBart (line 1423) | class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGenerati...
    method num_params (line 1439) | def num_params(self, params=None):
    method unscan (line 1447) | def unscan(self, params):
    method decode (line 1466) | def decode(
    method prepare_inputs_for_generation (line 1598) | def prepare_inputs_for_generation(
    method generate (line 1633) | def generate(
    method _sample (line 1811) | def _sample(
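
The `smelu` and `RMSNorm` entries in `modeling.py` name two standard building blocks, and the NumPy sketch below implements their textbook definitions.

- **SmeLU** is 0 below `-beta` and the identity above `beta`. Between those points it is the quadratic `(x + beta)**2 / (4 * beta)`.
- **RMSNorm** divides by the root-mean-square of the last axis and then applies a learned scale. It does no mean subtraction.

The sketch makes some substitutions. NumPy stands in for `jax.numpy`. `rms_norm` is a hypothetical free-function stand-in for the Flax module. `eps=1e-6` is an assumed default. Treat it as a sketch of the math rather than the repository's code.

```python
import numpy as np


def smelu(beta: float = 1.0):
    """Return a Smooth ReLU (SmeLU) activation with half-width `beta`."""

    def _smelu(x: np.ndarray) -> np.ndarray:
        quadratic = np.square(x + beta) / (4.0 * beta)
        # 0 at or below -beta, identity at or above beta, quadratic blend in between.
        return np.where(x <= -beta, 0.0, np.where(x >= beta, x, quadratic))

    return _smelu


def rms_norm(x: np.ndarray, scale: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize the last axis by its root-mean-square, then apply a learned scale."""
    # Unlike LayerNorm, there is no mean subtraction and no bias term.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(mean_sq + eps) * scale


act = smelu(beta=1.0)
print(act(np.array([-2.0, -1.0, 0.0, 1.0, 3.0])))
print(rms_norm(np.array([[3.0, 4.0]]), scale=np.ones(2)))
```

The quadratic piece matches both the value and the slope of its neighbours at `x = -beta` and `x = beta`. That is what makes SmeLU smooth where ReLU has a kink.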

FILE: src/dalle_mini/model/partitions.py
  function _match (line 15) | def _match(qs, ks):
  function _replacement_rules (line 26) | def _replacement_rules(rules):
  function _get_partition_rules (line 36) | def _get_partition_rules():
  function set_partitions (line 58) | def set_partitions(in_dict, use_scan):
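
`_match(qs, ks)` in `partitions.py` decides which partition rule applies to a parameter. The pure-Python sketch below assumes the convention common to Flax/T5X-style partitioning code. Under that convention, a rule is a tuple of regexes, and each regex must fully match one element of some contiguous window of the flattened parameter path. `match_rule` is a hypothetical name, and the example path reuses layer names from the parameter tree printed earlier in this file.

```python
import re
from typing import Sequence, Tuple


def match_rule(patterns: Sequence[str], key: Tuple[str, ...]) -> bool:
    """Return True if the regexes in `patterns` fully match a contiguous window of `key`."""
    # re.match is anchored at the start; appending "$" also anchors the end (full match).
    compiled = [re.compile(p + "$") for p in patterns]
    for start in range(len(key) - len(compiled) + 1):
        window = key[start : start + len(compiled)]
        if all(rx.match(part) for rx, part in zip(compiled, window)):
            return True
    return False


key = ("model", "decoder", "layers", "FlaxBartDecoderLayers", "GLU_0", "Dense_0", "kernel")
print(match_rule(("GLU_0", "Dense_0", "kernel"), key))  # True: window at the end of the path
print(match_rule(("Dense_[01]", "kernel"), key))  # True: elements may be regexes
print(match_rule(("embed_tokens", "embedding"), key))  # False: no such window
```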

FILE: src/dalle_mini/model/processor.py
  class DalleBartProcessorBase (line 13) | class DalleBartProcessorBase:
    method __init__ (line 14) | def __init__(
    method __call__ (line 33) | def __call__(self, text: List[str] = None):
    method from_pretrained (line 53) | def from_pretrained(cls, *args, **kwargs):
  class DalleBartProcessor (line 59) | class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):

FILE: src/dalle_mini/model/text.py
  class HashtagProcessor (line 21) | class HashtagProcessor:
    method __init__ (line 24) | def __init__(self):
    method __call__ (line 38) | def __call__(self, s):
    method _split (line 43) | def _split(self, s):
  function replace_person_token (line 86) | def replace_person_token(t):
  function fix_html (line 96) | def fix_html(t):
  function replace_punctuation_with_commas (line 101) | def replace_punctuation_with_commas(t):
  function simplify_quotes (line 105) | def simplify_quotes(t):
  function merge_quotes (line 109) | def merge_quotes(t):
  function remove_comma_numbers (line 113) | def remove_comma_numbers(t):
  function pre_process_dot_numbers (line 120) | def pre_process_dot_numbers(t):
  function post_process_dot_numbers (line 124) | def post_process_dot_numbers(t):
  function pre_process_quotes (line 128) | def pre_process_quotes(t):
  function post_process_quotes (line 135) | def post_process_quotes(t):
  function pre_process_dates (line 139) | def pre_process_dates(t):
  function post_process_dates (line 143) | def post_process_dates(t):
  function merge_commas (line 147) | def merge_commas(t):
  function add_space_after_commas (line 151) | def add_space_after_commas(t):
  function handle_special_chars (line 155) | def handle_special_chars(t):
  function expand_hashtags (line 163) | def expand_hashtags(t, hashtag_processor):
  function ignore_chars (line 171) | def ignore_chars(t):
  function remove_extra_spaces (line 176) | def remove_extra_spaces(t):
  function remove_repeating_chars (line 181) | def remove_repeating_chars(t):
  function remove_urls (line 186) | def remove_urls(t):
  function remove_html_tags (line 190) | def remove_html_tags(t):
  function remove_first_last_commas (line 194) | def remove_first_last_commas(t):
  function remove_wiki_ref (line 201) | def remove_wiki_ref(t):
  class TextNormalizer (line 206) | class TextNormalizer:
    method __init__ (line 209) | def __init__(self):
    method __call__ (line 212) | def __call__(self, t):

FILE: src/dalle_mini/model/tokenizer.py
  class DalleBartTokenizer (line 7) | class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):

FILE: src/dalle_mini/model/utils.py
  class PretrainedFromWandbMixin (line 8) | class PretrainedFromWandbMixin:
    method from_pretrained (line 10) | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, *...

FILE: tools/train/scalable_shampoo/distributed_shampoo.py
  class TrainingMetrics (line 58) | class TrainingMetrics:
  class ParameterStats (line 64) | class ParameterStats(NamedTuple):
  class GlobalShardedParameterStats (line 80) | class GlobalShardedParameterStats:
  class LocalShardedParameterStats (line 89) | class LocalShardedParameterStats:
  function init_training_metrics (line 102) | def init_training_metrics(num_statistics):
  function init_training_metrics_shapes (line 111) | def init_training_metrics_shapes(num_statistics):
  function init_training_metrics_pspec (line 120) | def init_training_metrics_pspec():
  class ShardedShampooStats (line 124) | class ShardedShampooStats(NamedTuple):
  class ShampooState (line 131) | class ShampooState(NamedTuple):
  class InitFnState (line 136) | class InitFnState(NamedTuple):
  class GraftingType (line 142) | class GraftingType(enum.IntEnum):
  class PreconditionerType (line 151) | class PreconditionerType(enum.IntEnum):
  function power_iteration (line 159) | def power_iteration(
  function mat_power (line 218) | def mat_power(
  function _pth_root_difference (line 246) | def _pth_root_difference(w, alpha, beta, p):
  function matrix_inverse_pth_root (line 267) | def matrix_inverse_pth_root(
  function merge_small_dims (line 422) | def merge_small_dims(shape_to_merge, max_dim):
  function pad_square_matrix (line 453) | def pad_square_matrix(mat, max_size):
  function make_sliced_padding (line 484) | def make_sliced_padding(
  function pad_block_symmetric_matrix (line 530) | def pad_block_symmetric_matrix(
  function pad_vector (line 588) | def pad_vector(vec, max_size):
  function efficient_cond (line 607) | def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
  class BlockPartitioner (line 623) | class BlockPartitioner:
    method __init__ (line 626) | def __init__(self, param, block_size):
    method split_sizes (line 645) | def split_sizes(self):
    method partition (line 648) | def partition(self, tensor):
    method merge_partitions (line 660) | def merge_partitions(self, partitions):
  function gram_weighted_update (line 677) | def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None):
  class Preconditioner (line 704) | class Preconditioner:
    method __init__ (line 707) | def __init__(
    method updated_statistics_from_grad (line 734) | def updated_statistics_from_grad(
    method should_precondition_dims (line 777) | def should_precondition_dims(self):
    method shapes_for_preconditioners (line 786) | def shapes_for_preconditioners(self):
    method exponent_for_preconditioner (line 799) | def exponent_for_preconditioner(self):
    method preconditioned_grad (line 805) | def preconditioned_grad(self, grad, preconditioners):
  function _convert_to_parameter_stats (line 841) | def _convert_to_parameter_stats(global_stats, local_stat, convert_statis...
  function _convert_from_parameter_stats (line 864) | def _convert_from_parameter_stats(parameter_stats, local_stats):
  function _add_error_into_local_stats (line 876) | def _add_error_into_local_stats(local_stats, errors, inverse_failure_thr...
  function batch (line 907) | def batch(x, num_devices):
  function unbatch (line 914) | def unbatch(batched_values):
  function distributed_shampoo (line 929) | def distributed_shampoo(
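
`merge_small_dims(shape_to_merge, max_dim)` reshapes a parameter so that Shampoo's per-dimension preconditioners stay small. The pure-Python sketch below assumes a greedy rule: adjacent dimensions are folded together while their product stays at or below `max_dim`. The real function may treat edge cases, such as an all-ones shape, differently.

```python
from typing import List, Sequence


def merge_small_dims(shape_to_merge: Sequence[int], max_dim: int) -> List[int]:
    """Greedily merge adjacent dimensions while their product stays <= max_dim."""
    merged: List[int] = []
    product = 1
    for dim in shape_to_merge:
        if product * dim <= max_dim:
            # Still small enough: fold this dimension into the running block.
            product *= dim
        else:
            # Close the current block (skipping a trivial one) and start a new block.
            if product > 1:
                merged.append(product)
            product = dim
    if product > 1:
        merged.append(product)
    # Assumption for this sketch: a shape made only of 1s collapses to [1].
    return merged or [1]


print(merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], max_dim=1024))  # [1024, 2048, 12]
```

The merged dimensions multiply to the same total as the original ones, so reshaping to the merged shape loses no elements.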

FILE: tools/train/scalable_shampoo/quantization_utils.py
  class QuantizedValue (line 27) | class QuantizedValue:
    method from_float_value (line 40) | def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=Fa...
    method quantize (line 58) | def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
    method to_float (line 107) | def to_float(self):

FILE: tools/train/scalable_shampoo/sm3.py
  class SM3State (line 37) | class SM3State(NamedTuple):
  class ParameterStats (line 43) | class ParameterStats(NamedTuple):
  function sm3 (line 50) | def sm3(

FILE: tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py
  class SlicedSymmetricMatrix (line 29) | class SlicedSymmetricMatrix:
  function product_with_transpose (line 42) | def product_with_transpose(
  function sliced_transposed_product (line 62) | def sliced_transposed_product(
  function sliced_transposed_product_concat (line 130) | def sliced_transposed_product_concat(
  function materialize_matrix (line 156) | def materialize_matrix(symmetric_matrix):
  function materialize_matrix_from_concat (line 193) | def materialize_matrix_from_concat(
  function update_sliced_rows (line 225) | def update_sliced_rows(
  function num_blocks_from_total_blocks (line 258) | def num_blocks_from_total_blocks(total_blocks):
  function find_num_blocks (line 280) | def find_num_blocks(block_rows_concat):
  function slice_symmetric_matrix (line 303) | def slice_symmetric_matrix(
  function slice_symmetric_matrix_concat (line 334) | def slice_symmetric_matrix_concat(
  function sliced_matrix_diag (line 348) | def sliced_matrix_diag(mat):
  function diag_as_concat (line 365) | def diag_as_concat(diag, block_size):
  function row_abs_maxes (line 382) | def row_abs_maxes(mat):
  function times_vector (line 420) | def times_vector(mat, vec):

FILE: tools/train/train.py
  class ModelArguments (line 73) | class ModelArguments:
    method __post_init__ (line 123) | def __post_init__(self):
    method get_metadata (line 134) | def get_metadata(self):
    method get_opt_state (line 144) | def get_opt_state(self):
  class DataTrainingArguments (line 178) | class DataTrainingArguments:
    method __post_init__ (line 292) | def __post_init__(self):
  class TrainingArguments (line 298) | class TrainingArguments:
    method __post_init__ (line 510) | def __post_init__(self):
  function split_params (line 569) | def split_params(data):
  function unsplit_params (line 587) | def unsplit_params(data):
  function trainable_params (line 595) | def trainable_params(data, embeddings_only):
  function init_embeddings (line 619) | def init_embeddings(model, params):
  function main (line 637) | def main():
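
The embeddings-retraining notebook earlier in this file deletes `lm_head` and the decoder's `embed_positions`, `embed_tokens`, `final_ln` and `layernorm_embedding`. It then reports 404,012,016 remaining parameters. The pure-Python sketch below recomputes that total from the printed shapes, using the same sum of leaf sizes the notebook computes. It models only shapes, not arrays, and the helper names (`attention`, `glu`, `leaves`) are hypothetical.

```python
from math import prod
from typing import Dict, List, Tuple, Union

# Nested dict whose leaves are parameter shapes, as printed by the notebook.
ShapeTree = Union[Tuple[int, ...], Dict[str, "ShapeTree"]]

L, D, F = 12, 1024, 2730  # scanned layers, model width, GLU hidden width (from the printed shapes)


def attention() -> ShapeTree:
    return {name: {"kernel": (L, D, D)} for name in ("k_proj", "out_proj", "q_proj", "v_proj")}


def glu() -> ShapeTree:
    return {
        "Dense_0": {"kernel": (L, D, F)},
        "Dense_1": {"kernel": (L, D, F)},
        "Dense_2": {"kernel": (L, F, D)},
        "LayerNorm_0": {"bias": (L, D)},
        "LayerNorm_1": {"bias": (L, F)},
    }


def leaves(tree: ShapeTree) -> List[Tuple[int, ...]]:
    """Collect every shape tuple in the nested dict."""
    if isinstance(tree, tuple):
        return [tree]
    return [shape for child in tree.values() for shape in leaves(child)]


decoder_layer = {
    "FlaxBartAttention_0": attention(),
    "FlaxBartAttention_1": attention(),
    "GLU_0": glu(),
    "LayerNorm_0": {"bias": (L, D)},
    "LayerNorm_1": {"bias": (L, D), "scale": (L, D)},
    "LayerNorm_2": {"bias": (L, D)},
    "LayerNorm_3": {"bias": (L, D), "scale": (L, D)},
}
encoder_layer = {
    "FlaxBartAttention_0": attention(),
    "GLU_0": glu(),
    "LayerNorm_0": {"bias": (L, D)},
    "LayerNorm_1": {"bias": (L, D), "scale": (L, D)},
}
trimmed = {
    "model": {
        "decoder": {"layers": {"FlaxBartDecoderLayers": decoder_layer}},
        "encoder": {
            "embed_positions": {"embedding": (64, D)},
            "embed_tokens": {"embedding": (50264, D)},
            "final_ln": {"bias": (D,)},
            "layernorm_embedding": {"bias": (D,), "scale": (D,)},
            "layers": {"FlaxBartEncoderLayers": encoder_layer},
        },
    }
}

total = sum(prod(shape) for shape in leaves(trimmed))
print(f"{total:,}")  # 404,012,016
```

Summing the decoder and encoder subtrees separately gives 201,420,792 and 202,591,224 parameters. The encoder's token embedding alone accounts for 51,470,336 of its share.
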
Condensed preview: 48 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitattributes",
    "chars": 690,
    "preview": "*.bin.* filter=lfs diff=lfs merge=lfs -text\n*.lfs.* filter=lfs diff=lfs merge=lfs -text\n*.bin filter=lfs diff=lfs merge="
  },
  {
    "path": ".github/FUNDING.yml",
    "chars": 21,
    "preview": "github: [borisdayma]\n"
  },
  {
    "path": ".github/workflows/check_size.yml",
    "chars": 362,
    "preview": "name: Check file size\n\non:\n  pull_request:\n    branches: [main]\n\n  # to run this workflow manually from the Actions tab\n"
  },
  {
    "path": ".github/workflows/pypi_release.yml",
    "chars": 847,
    "preview": "# This workflow uses actions that are not certified by GitHub.\n# They are provided by a third-party and are governed by\n"
  },
  {
    "path": ".github/workflows/style.yml",
    "chars": 397,
    "preview": "name: Lint\n\non:\n  push:\n    branches: [main]\n  pull_request:\n    branches: [main]\n\njobs:\n  lint:\n    runs-on: ubuntu-lat"
  },
  {
    "path": ".github/workflows/sync_to_hub.yml.backup",
    "chars": 489,
    "preview": "name: Sync to Hugging Face hub - Obsolete to avoid app disruptions\n\non:\n  push:\n    branches: [main]\n\n  # to run this wo"
  },
  {
    "path": ".github/workflows/sync_to_hub_debug.yml",
    "chars": 444,
    "preview": "name: Deploy to debug app\n\non:\n  # to run this workflow manually from the Actions tab\n  workflow_dispatch:\n\njobs:\n  sync"
  },
  {
    "path": ".gitignore",
    "chars": 72,
    "preview": "__pycache__\n.ipynb_checkpoints\n.streamlit\nwandb/\n*.egg-info/\njax_cache/\n"
  },
  {
    "path": "CITATION.cff",
    "chars": 1387,
    "preview": "# YAML 1.2\n---\nabstract: \"DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardw"
  },
  {
    "path": "Docker/Dockerfile",
    "chars": 445,
    "preview": "FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04\n\nRUN apt-get update && apt-get install -y \\\n  git \\\n  python3 \\\n  pytho"
  },
  {
    "path": "Docker/README.md",
    "chars": 715,
    "preview": "# Running Dalle-mini With Docker\r\n\r\nThis folder contains the Dockerfile needed to build a Docker image that can easily r"
  },
  {
    "path": "Docker/build_docker.sh",
    "chars": 36,
    "preview": "docker build . -t dalle-mini:latest\n"
  },
  {
    "path": "LICENSE",
    "chars": 11353,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "Makefile",
    "chars": 39,
    "preview": ".PHONY: style\n\nstyle:\n\tblack .\n\tisort ."
  },
  {
    "path": "README.md",
    "chars": 10627,
    "preview": "# DALL·E Mini\n\n<a href=\"https://www.craiyon.com/\"><img src=\"https://www.craiyon.com/thumbnail.png\" width=\"300\"></a>\n\n## "
  },
  {
    "path": "app/gradio/app.py",
    "chars": 1854,
    "preview": "#!/usr/bin/env python\n# coding: utf-8\nimport os\n\nimport gradio as gr\nfrom backend import get_images_from_backend\n\nblock "
  },
  {
    "path": "app/gradio/backend.py",
    "chars": 863,
    "preview": "# Client requests to Dalle-Mini Backend server\n\nimport base64\nfrom io import BytesIO\n\nimport requests\nfrom PIL import Im"
  },
  {
    "path": "app/streamlit/app.py",
    "chars": 3571,
    "preview": "#!/usr/bin/env python\n# coding: utf-8\n\nimport streamlit as st\nfrom backend import ServiceError, get_images_from_backend\n"
  },
  {
    "path": "app/streamlit/backend.py",
    "chars": 863,
    "preview": "# Client requests to Dalle-Mini Backend server\n\nimport base64\nfrom io import BytesIO\n\nimport requests\nfrom PIL import Im"
  },
  {
    "path": "pyproject.toml",
    "chars": 30,
    "preview": "[tool.isort]\nprofile = \"black\""
  },
  {
    "path": "run_docker_image.sh",
    "chars": 240,
    "preview": "#!/bin/bash\n\n# This script is used to run the docker image. Change or remove GPU flag if you dont have nvidia-docker or "
  },
  {
    "path": "setup.cfg",
    "chars": 1079,
    "preview": "[metadata]\nname = dalle-mini\nversion = attr: dalle_mini.__version__\nauthor = Boris Dayma et al.\nauthor_email = boris.day"
  },
  {
    "path": "setup.py",
    "chars": 69,
    "preview": "from setuptools import setup\n\nif __name__ == \"__main__\":\n    setup()\n"
  },
  {
    "path": "src/dalle_mini/__init__.py",
    "chars": 72,
    "preview": "__version__ = \"0.1.5\"\n\nfrom .model import DalleBart, DalleBartProcessor\n"
  },
  {
    "path": "src/dalle_mini/data.py",
    "chars": 17784,
    "preview": "import random\nfrom dataclasses import dataclass, field\nfrom functools import partial\nfrom pathlib import Path\n\nimport ja"
  },
  {
    "path": "src/dalle_mini/model/__init__.py",
    "chars": 198,
    "preview": "from .configuration import DalleBartConfig\nfrom .modeling import DalleBart\nfrom .partitions import set_partitions\nfrom ."
  },
  {
    "path": "src/dalle_mini/model/configuration.py",
    "chars": 7947,
    "preview": "# coding=utf-8\n# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed und"
  },
  {
    "path": "src/dalle_mini/model/modeling.py",
    "chars": 70641,
    "preview": "# coding=utf-8\n# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team "
  },
  {
    "path": "src/dalle_mini/model/partitions.py",
    "chars": 2506,
    "preview": "import re\n\nfrom flax.core.frozen_dict import freeze\nfrom flax.traverse_util import flatten_dict, unflatten_dict\nfrom jax"
  },
  {
    "path": "src/dalle_mini/model/processor.py",
    "chars": 2038,
    "preview": "\"\"\" DalleBart processor \"\"\"\n\nfrom typing import List\n\nimport jax.numpy as jnp\n\nfrom .configuration import DalleBartConfi"
  },
  {
    "path": "src/dalle_mini/model/text.py",
    "chars": 7392,
    "preview": "\"\"\"\nUtilities for processing text.\n\"\"\"\n\nimport html\nimport math\nimport random\nimport re\nfrom pathlib import Path\n\nimport"
  },
  {
    "path": "src/dalle_mini/model/tokenizer.py",
    "chars": 198,
    "preview": "\"\"\" DalleBart tokenizer \"\"\"\nfrom transformers import BartTokenizerFast\n\nfrom .utils import PretrainedFromWandbMixin\n\n\ncl"
  },
  {
    "path": "src/dalle_mini/model/utils.py",
    "chars": 1028,
    "preview": "import os\nimport tempfile\nfrom pathlib import Path\n\nimport wandb\n\n\nclass PretrainedFromWandbMixin:\n    @classmethod\n    "
  },
  {
    "path": "tools/dataset/encode_dataset.ipynb",
    "chars": 9544,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d0b72877\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Pre-encoding"
  },
  {
    "path": "tools/inference/inference_pipeline.ipynb",
    "chars": 16110,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"view-in-github\",\n        \"colab_t"
  },
  {
    "path": "tools/inference/run_infer_notebook.sh",
    "chars": 67,
    "preview": "#!/bin/bash\njupyter notebook --ip 0.0.0.0 --no-browser --allow-root"
  },
  {
    "path": "tools/train/config/mega/config.json",
    "chars": 1287,
    "preview": "{\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"gelu\",\n  \"attention_dropout\": 0.0,\n  \"bos_token_id\": 16385,\n  \""
  },
  {
    "path": "tools/train/config/micro/config.json",
    "chars": 751,
    "preview": "{\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"gelu\",\n  \"attention_dropout\": 0.0,\n  \"bos_token_id\": 16385,\n  \""
  },
  {
    "path": "tools/train/config/mini/config.json",
    "chars": 737,
    "preview": "{\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"gelu\",\n  \"attention_dropout\": 0.0,\n  \"bos_token_id\": 16385,\n  \""
  },
  {
    "path": "tools/train/config/mini_glu/config.json",
    "chars": 758,
    "preview": "{\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"gelu\",\n  \"attention_dropout\": 0.0,\n  \"bos_token_id\": 16385,\n  \""
  },
  {
    "path": "tools/train/embeddings_retrain_preparation.ipynb",
    "chars": 38894,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"118UKH5bWCGa\"\n   },\n   \"source\": [\n    \"# DALL"
  },
  {
    "path": "tools/train/scalable_shampoo/README.md",
    "chars": 264,
    "preview": "# Notes\n\nFiles copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/t"
  },
  {
    "path": "tools/train/scalable_shampoo/distributed_shampoo.py",
    "chars": 96206,
    "preview": "# coding=utf-8\n# Copyright 2022 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "tools/train/scalable_shampoo/quantization_utils.py",
    "chars": 4585,
    "preview": "# coding=utf-8\n# Copyright 2022 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "tools/train/scalable_shampoo/sm3.py",
    "chars": 6157,
    "preview": "# coding=utf-8\n# Copyright 2022 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py",
    "chars": 14470,
    "preview": "# coding=utf-8\n# Copyright 2022 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "tools/train/sweep.yaml",
    "chars": 883,
    "preview": "program: train.py\nproject: dalle-mini\nmethod: random\nmetric:\n  name: eval/loss\n  goal: minimize\nparameters:\n  optim:\n   "
  },
  {
    "path": "tools/train/train.py",
    "chars": 65368,
    "preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.\n#\n# "
  }
]
