Full Code of LargeWorldModel/LWM for AI

Repository: LargeWorldModel/LWM
Branch: main
Commit: f45d2b70bda2
Files: 26
Total size: 244.6 KB

Directory structure:
gitextract_q2zwva4c/

├── .gitignore
├── LICENSE
├── README.md
├── docs/
│   ├── data.md
│   └── sharding.md
├── gpu_requirements.txt
├── lwm/
│   ├── __init__.py
│   ├── data.py
│   ├── llama.py
│   ├── train.py
│   ├── vision_chat.py
│   ├── vision_generation.py
│   ├── vision_llama.py
│   └── vqgan.py
├── scripts/
│   ├── create_needle_data.py
│   ├── eval_needle.py
│   ├── eval_needle_multi.py
│   ├── run_eval_needle.sh
│   ├── run_eval_needle_multi.sh
│   ├── run_sample_image.sh
│   ├── run_sample_video.sh
│   ├── run_train_text.sh
│   ├── run_train_vision_text.sh
│   ├── run_vision_chat.sh
│   └── sample_pyt.py
└── tpu_requirements.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# local
jobs/
local/
.vscode/

data/
*.model
*.npy
*.jsonl
*.pkl
*.json
__pycache__/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Large World Model (LWM)

[[Project]](https://largeworldmodel.github.io/)
[[Paper]](https://arxiv.org/abs/2402.08268)
[[Models]](https://huggingface.co/LargeWorldModel)

**Large World Model (LWM)** is a general-purpose large-context multimodal autoregressive model. It is trained on a large dataset of diverse long videos and books using RingAttention, and can perform language, image, and video understanding and generation.


## Approach

<div align="center">
  <img src="./imgs/data.png"/>
</div>

Current language models fall short in understanding aspects of the world not easily described in words, and struggle with complex, long-form tasks. Video sequences offer valuable temporal information absent in language and static images, making them attractive for joint modeling with language. Such models could develop an understanding of both human textual knowledge and the physical world, enabling broader AI capabilities for assisting humans. However, learning from millions of tokens of video and language sequences poses challenges due to memory constraints, computational complexity, and limited datasets. To address these challenges, we curate a large dataset of diverse videos and books, utilize the RingAttention technique to scalably train on long sequences, and gradually increase context size from 4K to 1M tokens. This paper makes the following contributions: (a) Largest context size neural network: We train one of the largest context size transformers on long video and language sequences, setting new benchmarks in difficult retrieval tasks and long video understanding. (b) Solutions for overcoming vision-language training challenges, including using masked sequence packing for mixing different sequence lengths, loss weighting to balance language and vision, and model-generated QA dataset for long sequence chat. (c) A highly-optimized implementation with RingAttention, masked sequence packing, and other key features for training on millions-length multimodal sequences. (d) Fully open-sourced a family of 7B parameter models capable of processing long text documents (LWM-Text, LWM-Text-Chat) and videos (LWM, LWM-Chat) of over 1M tokens.
This work paves the way for training on massive datasets of long video and language to develop understanding of both human knowledge and the multimodal world, and broader capabilities.

## LWM Capabilities

<div align="center">
  <img src="./imgs/single_needle_1M.png"/>
  <p>
  LWM can retrieve facts across a 1M-token context with high accuracy.
  </p>
</div>

<br />

<div align="center">
  <img src="./imgs/long_video_chat_main.png"/>
  <p>
  LWM can answer questions about a 1-hour YouTube video.
  </p>
</div>

<br />

<div align="center">
  <img src="./imgs/image_chat.png"/>
  <p>
  LWM can chat with images.
  </p>
</div>

<br />

<div align="center">
  <img src="./imgs/image_video_gen.png"/>
  <p>
  LWM can generate videos and images from text.
  </p>
</div>


## Setup

This codebase is supported on Ubuntu and has not been tested on Windows or macOS. We recommend using TPUs for training and inference, although it is also possible to use GPUs. On TPU, the code is highly optimized with Jax's Pallas and can achieve high MFUs with RingAttention at very large context sizes. On GPU, the code is based on XLA and is not as optimized as it is for TPU.

Install the requirements with:
```
conda create -n lwm python=3.10
conda activate lwm
pip install -r gpu_requirements.txt
```
or set up TPU VM with:
```
sh tpu_requirements.sh
```


## Available models

There are language-only and vision-language versions, with context sizes of 32K, 128K, 256K, 512K, and 1M tokens. The vision-language models are available only in Jax, and the language-only models are available in both PyTorch and Jax. Below are the names of the available models and their corresponding context sizes and capabilities:

| Model Name         | Context Size | Language or Vision-Language | Chat or Base | URL                                                                                                                                          |
|--------------------|--------------|-----------------------------|--------------|----------------------------------------------------------------------------------------------------------------------------------------------|
| LWM-Text-Chat-128K | 128K         | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K-Jax)] |
| LWM-Text-Chat-256K | 256K         | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K-Jax)] |
| LWM-Text-Chat-512K | 512K         | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K-Jax)] |
| LWM-Text-Chat-1M   | 1M           | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M-Jax)]     |
| LWM-Text-128K      | 128K         | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-128K-Jax)]           |
| LWM-Text-256K      | 256K         | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-256K-Jax)]           |
| LWM-Text-512K      | 512K         | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-512K-Jax)]           |
| LWM-Text-1M        | 1M           | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-1M-Jax)]               |
| LWM-Chat-32K       | 32K          | Vision-Language             | Chat         | [[Jax](https://huggingface.co/LargeWorldModel/LWM-32K-Jax)]                                                                                  |
| LWM-Chat-128K      | 128K         | Vision-Language             | Chat         | [[Jax](https://huggingface.co/LargeWorldModel/LWM-128K-Jax)]                                                                                 |
| LWM-Chat-1M        | 1M           | Vision-Language             | Chat         | [[Jax](https://huggingface.co/LargeWorldModel/LWM-1M-Jax)]                                                                                   |


## Code structure
Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in the blockwise computation of self-attention. Use `scan_mlp_chunk_size` to control the block size in the blockwise computation of the feedforward network. Set `scan_attention=True` and `scan_mlp=True` to enable blockwise computation in the self-attention and feedforward networks, respectively, or set them to `False` to disable it.
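
To make the role of a key chunk size concrete, here is a toy, pure-Python sketch of blockwise attention for a single query. It is not the repository's JAX implementation: the function names are made up for illustration, the dot product is unscaled, and there is no causal mask. It only shows that scoring keys one chunk at a time, while tracking a running maximum and a running softmax denominator, gives the same result as scoring all keys at once.

```python
import math


def attention_full(query, keys, values):
    """Reference: softmax attention for one query over all keys at once."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def attention_blockwise(query, keys, values, key_chunk_size):
    """Same result, but only `key_chunk_size` keys are scored at a time."""
    dim = len(values[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # softmax denominator, relative to running_max
    acc = [0.0] * dim        # weighted sum of values, relative to running_max
    for start in range(0, len(keys), key_chunk_size):
        chunk_keys = keys[start:start + key_chunk_size]
        chunk_values = values[start:start + key_chunk_size]
        scores = [sum(q * k for q, k in zip(query, key)) for key in chunk_keys]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (0.0 on the first chunk).
        rescale = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * rescale + sum(weights)
        acc = [
            a * rescale + sum(w * v[d] for w, v in zip(weights, chunk_values))
            for d, a in enumerate(acc)
        ]
        running_max = new_max
    return [a / running_sum for a in acc]


query = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

full = attention_full(query, keys, values)
blocked = attention_blockwise(query, keys, values, key_chunk_size=2)
assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(full, blocked))
```

A smaller chunk size lowers the peak number of scores held at once at the cost of more loop iterations, which is the trade-off the chunk-size options expose.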

You can use `mesh_dim=dp,fsdp,tp,sp` to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, giving the degrees of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism, in that order.
For example, `mesh_dim='1,64,4,1'` means 1-way data parallelism, 64-way fully sharded data parallelism, 4-way tensor parallelism, and 1-way sequence parallelism. `mesh_dim='1,1,4,64'` means 1-way data parallelism, 1-way fully sharded data parallelism, 4-way tensor parallelism, and 64-way sequence parallelism for RingAttention.
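
As a rough illustration of how such a string maps onto the four axes, here is a small parser. `parse_mesh_dim` is a hypothetical helper, not a function from this codebase. It assumes that the axis sizes must multiply to the number of devices and that a single `-1` (as in the `'!1,1,-1,1'` value used in the sample script in this README) stands for "all remaining devices". The leading `!` is simply stripped, since its meaning is not described here.

```python
import math

AXIS_NAMES = ("dp", "fsdp", "tp", "sp")


def parse_mesh_dim(mesh_dim, num_devices):
    """Map a 'dp,fsdp,tp,sp' string to a dict of axis sizes (illustrative only)."""
    dims = [int(part) for part in mesh_dim.lstrip("!").split(",")]
    if len(dims) != len(AXIS_NAMES):
        raise ValueError(f"expected {len(AXIS_NAMES)} integers, got {len(dims)}")
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        # Assumption: -1 is filled in with whatever device count is left over.
        if num_devices % known != 0:
            raise ValueError("device count is not divisible by the fixed axes")
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError(f"axes multiply to {known}, but there are {num_devices} devices")
    return dict(zip(AXIS_NAMES, dims))


print(parse_mesh_dim("1,64,4,1", num_devices=256))
print(parse_mesh_dim("1,1,4,64", num_devices=256))
print(parse_mesh_dim("!1,1,-1,1", num_devices=8))
```

Both README examples describe a 256-device mesh; they differ only in whether the 64-way axis is spent on parameter sharding or on sequence parallelism.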


## Running Jax Models
In this section, we provide instructions on how to run each of the provided scripts. For each script, you may need to fill in your own paths and values in the variables described in the beginning of each script.

To run each of the following scripts, use `bash <script_name>.sh`:
- Language model training: `bash scripts/run_train_text.sh`
- Vision-Language model training: `bash scripts/run_train_vision_text.sh`
- Single Needle Evals (Language Model): `bash scripts/run_eval_needle.sh`
- Multi Needle Evals (Language Model): `bash scripts/run_eval_needle_multi.sh`
- Sampling images (Vision-Language Model): `bash scripts/run_sample_image.sh`
- Sampling videos (Vision-Language Model): `bash scripts/run_sample_video.sh`
- Image / Video understanding (Vision-Language Model): `bash scripts/run_vision_chat.sh`

By default the `mesh_dim` argument puts all devices on `tp` (tensor parallelism). For longer sequences, you may want to include `sp`, which is the last dimension in the `mesh_dim`.

When running needle evals, you may need to adjust the `theta` and `max_sequence_length` arguments in the scripts depending on the model. The table below shows the correct values for each model.

|                     | LWM-Text-128K /  LWM-Text-Chat-128K | LWM-Text-256K /  LWM-Text-Chat-256K | LWM-Text-512K / LWM-Text-Chat-512K | LWM-Text-1M / LWM-Text-Chat-1M |
|---------------------|:-----------------------------------:|:-----------------------------------:|:----------------------------------:|:------------------------------:|
| theta               |               10000000              |               10000000              |              25000000              |            50000000            |
| max_sequence_length |                131072               |                262144               |               524288               |             1048576            |
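
The table can also be kept as a small lookup when switching between checkpoints. The sketch below is only a convenience, not part of the repository: `NEEDLE_EVAL_SETTINGS` and `llama_config_override` are hypothetical names, and it assumes the eval scripts accept these two values inside `--update_llama_config="dict(...)"` in the same way the sample script in this README does.

```python
# Values copied from the table above, keyed by context size.
NEEDLE_EVAL_SETTINGS = {
    "128K": {"theta": 10_000_000, "max_sequence_length": 131_072},
    "256K": {"theta": 10_000_000, "max_sequence_length": 262_144},
    "512K": {"theta": 25_000_000, "max_sequence_length": 524_288},
    "1M": {"theta": 50_000_000, "max_sequence_length": 1_048_576},
}


def llama_config_override(context_size):
    """Build the dict(...) string for a given context size (illustrative only)."""
    settings = NEEDLE_EVAL_SETTINGS[context_size]
    return (
        f"dict(theta={settings['theta']},"
        f"max_sequence_length={settings['max_sequence_length']})"
    )


# Each max_sequence_length is the nominal context size in units of 1024 tokens.
for name, kilo_tokens in [("128K", 128), ("256K", 256), ("512K", 512), ("1M", 1024)]:
    assert NEEDLE_EVAL_SETTINGS[name]["max_sequence_length"] == kilo_tokens * 1024

print(llama_config_override("1M"))
```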


An example of filling out a script (`run_sample_video.sh`) is as follows:
```bash
#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export vqgan_checkpoint="/path/to/ckpt/folder/vqgan"
export lwm_checkpoint="params::/path/to/ckpt/folder/params"

python3 -u -m lwm.vision_generation \
    --prompt='Fireworks over the city' \
    --output_file='fireworks.mp4' \
    --temperature_image=1.0 \
    --temperature_video=1.0 \
    --top_k_image=8192 \
    --top_k_video=1000 \
    --cfg_scale_image=5.0 \
    --cfg_scale_video=1.0 \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --n_frames=8 \
    --mesh_dim='!1,1,-1,1' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --load_checkpoint="$lwm_checkpoint" \
    --tokenizer="$llama_tokenizer_path"
read
```


## Needle Haystack Data
Run `python scripts/create_needle_data.py`


## Running PyTorch Models
Only text and text chat models are currently supported for PyTorch inference. PyTorch models can be loaded as Hugging Face `LlamaForCausalLM` models. Run `python scripts/sample_pyt.py` to sample. You may need to separately install `torch`.

## Documentation

For more details on the codebase, please refer to the [data.md](docs/data.md) and [sharding.md](docs/sharding.md).
The [data.md](docs/data.md) provides details on the data processing and the [sharding.md](docs/sharding.md) provides details on the sharding and parallelism.


## If you have issues

This is based on the [codebase](https://github.com/haoliuhl/ringattention) of RingAttention, with the necessary features for vision-language training. The training and inference have been tested on both TPUv3 and TPUv4.

If you encounter bugs, please open a GitHub issue!


## Citation

If you use this codebase, or otherwise found our work valuable, please cite:

```
@article{liu2023world,
    title={World Model on Million-Length Video and Language with RingAttention},
    author={Liu, Hao and Yan, Wilson and Zaharia, Matei and Abbeel, Pieter},
    journal={arXiv preprint},
    year={2024},
}
@article{liu2023ring,
    title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
    journal={International Conference on Learning Representations},
    year={2024}
}
@article{liu2023blockwise,
    title={Blockwise Parallel Transformer for Large Context Models},
    author={Liu, Hao and Abbeel, Pieter},
    journal={Advances in neural information processing systems},
    year={2023}
}
```

## License

LWM's code is released under the Apache 2.0 License. See [LICENSE](https://github.com/LargeWorldModel/lwm/blob/main/LICENSE) for further details. The models are released under the Llama-2 license.


================================================
FILE: docs/data.md
================================================
# Data

We support two types of datasets: Huggingface dataset and JSON dataset. The dataset modules are implemented in the [data.py](/lwm/data.py) file.

Configuration requires dataset type, text processor, and dataset specific configurations.
The following is an example of using the Huggingface dataset to train a model:
```bash
python -m lwm.train \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.type='huggingface' \
    --train_dataset.huggingface_dataset.path='openwebtext'
```

In this example, we select the Huggingface dataset by specifying the `type` of
`train_dataset` to be `huggingface`. We then specify the path to the dataset,
which is `openwebtext` in this case. The examples loaded from the dataset will be processed
by a TextProcessor, which is configured by the `text_processor` field.

The following options are supported for the dataset module:
* `type`: The type of the dataset. Supported values are `huggingface` and `json`.
* `text_processor`: The configuration of the TextProcessor used to process the
  loaded examples.
* `huggingface_dataset`: The configuration of the Huggingface dataset.
* `json_dataset`: The configuration of the JSON dataset.

For the Huggingface dataset, examples are loaded from a Huggingface dataset. Here are the configurable options:
* `path`: The path to the dataset. Same as the `path` argument in
  `datasets.load_dataset`.
* `name`: Name of the dataset within the path. Same as the `name` argument in
  `datasets.load_dataset`.
* `split`: The split of the dataset. Same as the `split` argument in
  `datasets.load_dataset`.
* `streaming`: Whether to stream the dataset. Same as the `streaming` argument
  in `datasets.load_dataset`.
* `seq_length`: The length of the tokenized sequence.
* `batch_size`: Batch size of tokenized examples.
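
The nested options map onto dotted command-line flags, as in the `--train_dataset.huggingface_dataset.path` flag shown earlier. The sketch below flattens a nested dictionary into that form. `to_flags` is a hypothetical helper, every value is a placeholder, and it assumes the remaining options (`split`, `streaming`, `seq_length`, `batch_size`) follow the same dotted naming as `path`.

```python
def to_flags(config, prefix=""):
    """Flatten a nested dict into dotted `--a.b.c=value` command-line flags."""
    flags = []
    for key, value in config.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flags.extend(to_flags(value, name))
        else:
            flags.append(f"--{name}={value}")
    return flags


# All values below are placeholders chosen for illustration.
train_dataset = {
    "type": "huggingface",
    "text_processor": {"fields": "text"},
    "huggingface_dataset": {
        "path": "openwebtext",
        "split": "train",
        "streaming": True,
        "seq_length": 1024,
        "batch_size": 8,
    },
}

command = ["python", "-m", "lwm.train", *to_flags({"train_dataset": train_dataset})]
print(" \\\n    ".join(command))
```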

For the JSON dataset, examples are loaded from a text file in which each line is a
JSON-encoded dictionary. Here are the configurable options for the JSON dataset:
* `path`: Path to the text file. The file can be located on the local file system
  or on Google Cloud Storage bucket.
* `seq_length`: The length of the tokenized sequence.
* `batch_size`: Batch size of tokenized examples.
* `start_seek_loc`: The starting seek location in the file. This is useful when
  you want to resume training from a particular location in the file.
* `index_at_start`: The counting index at the beginning. This is useful to
  keep the index count when resuming from a particular location in the file.
  Note that this is only for logging purposes and does not affect which
  example loading starts from. To start from a different example in the dataset,
  you should use the `start_seek_loc` option.
* `tokenizer_processes`: The number of processes to use for tokenization.
  Tokenization is done in parallel to speed up the loading process.

A JSON dataset can be generated as follows:
```python
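import json
from multiprocessing import Pool

from datasets import load_dataset

dataset = load_dataset("openwebtext")

# Hold out a small validation split; the fixed seed keeps the split reproducible.
split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
split_dataset["val"] = split_dataset.pop("test")


def save_split(split):
    # Write one JSON object per line, each with a single "text" field.
    with open(f"openwebtext_{split}.jsonl", "w") as f:
        for example in split_dataset[split]:
            json.dump({"text": example["text"]}, f)
            f.write("\n")


if __name__ == "__main__":
    # There are only two splits, so two workers are enough. The guard keeps
    # this safe on platforms where multiprocessing starts fresh interpreters.
    with Pool(processes=2) as p:
        p.map(save_split, ["train", "val"])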
```

This generates two files, `openwebtext_train.jsonl` and `openwebtext_val.jsonl`, which can be used as the dataset for training. Both files contain a single field, `text`, which is the text to be processed by the model.
For example, to train a model using the `openwebtext_train.jsonl` file, you can use the following command:
```bash
python -m lwm.train \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.type='json' \
    --train_dataset.json_dataset.path='openwebtext_train.jsonl'
```

For vision-language training, we recommend using the JSON dataset, as it allows you to pre-tokenize vision (images and videos), and load the tokenized vision along with the text.

Each loaded example is a dictionary, which will be processed by a `TextProcessor`.


## Text Processor
We use the `TextProcessor` class to process the loaded examples from a dataset. This allows us to flexibly process various formats.
Each input example is a dictionary of multiple text fields. The TextProcessor will
process text fields according to its configurations, and return the final tokens.

Here are the configurable options for TextProcessor:
* `fields`: A comma separated list of text fields to process.
* `fields_from_example`: Whether to use the keys of the input example as the
  text fields to process. If this option is set, the `fields` argument will
  be ignored.
* `subfield_separator`: The text separator to use when concatenating the
  subfields of a text field.
* `add_eos_token`: Whether to add an EOS token to the end of the text.
* `prepend_text`: The text to prepend to the beginning.

The most important configuration for TextProcessor is the `fields` argument. It
is a comma separated list of text fields to process. Each field consists of one
or more subfields, which are separated by a `+`. Each subfield represents a key
used to extract the text from the input example dictionary. The TextProcessor
joins the extracted subfield texts with the `subfield_separator` at the text
level and then tokenizes the joined text. Finally, the TextProcessor concatenates
the tokenized text fields at the token level, and adds the EOS token if specified.

Other than the keys in the input example, you can also use the following special
keys to indicate a special token for a text field:
* `<|bos|>`: Beginning of sentence token.
* `<|eos|>`: End of sentence token.

For each text field, you can enclose the subfields in `[]` to specify that
the loss should not be computed for this field. Doing so sets the loss
masks to 0 for this field. This is useful when you want to use the text field
as a prompt for the model.


To give a concrete example, if the input example looks like this:
```python
{
    'vision': 'VQ tokens of a picture of a cat',
    'question': 'what is the color of the cat',
    'answer': 'the color of the cat is yellow',
}
```
To use the `vision` and `question` as the input text, and `answer` as the target,
we can specify the following configuration for the `fields` argument:
```
[vision+question],answer
```

The `vision+question` indicates that the `vision` and `question` should be joined
together with the `subfield_separator`, which is a space by default. The `[]`
indicates that the loss should not be computed for this field. The `answer` field
is then concatenated at the token level, where the loss will be computed.
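
The following sketch walks through this example end to end. It is a simplified illustration rather than the `TextProcessor` implementation: `toy_encode` is a hypothetical stand-in tokenizer that emits one token per word, and BOS/EOS handling and the special `<|bos|>`/`<|eos|>` keys are left out.

```python
def toy_encode(text):
    """Stand-in tokenizer (assumption): one 'token' per whitespace-separated word."""
    return text.split()


def process_fields(example, fields, subfield_separator=" "):
    """Return (tokens, loss_masks) for a `fields` spec such as '[a+b],c'."""
    tokens, loss_masks = [], []
    for field in fields.split(","):
        # Square brackets mark a field whose tokens carry no loss.
        if field.startswith("[") and field.endswith("]"):
            field, mask = field[1:-1], 0.0
        else:
            mask = 1.0
        # Subfields are joined at the text level before tokenization.
        text = subfield_separator.join(example[key] for key in field.split("+"))
        field_tokens = toy_encode(text)
        # Fields are concatenated at the token level.
        tokens.extend(field_tokens)
        loss_masks.extend([mask] * len(field_tokens))
    return tokens, loss_masks


example = {
    "vision": "VQ tokens of a picture of a cat",
    "question": "what is the color of the cat",
    "answer": "the color of the cat is yellow",
}
tokens, loss_masks = process_fields(example, "[vision+question],answer")
n_prompt = len(toy_encode(example["vision"] + " " + example["question"]))
print(loss_masks[:n_prompt])  # prompt positions: masked out
print(loss_masks[n_prompt:])  # answer positions: contribute to the loss
```

With this toy tokenizer the prompt covers 15 tokens with mask 0.0 and the answer covers 7 tokens with mask 1.0; a real tokenizer would produce different counts but the same mask pattern.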



================================================
FILE: docs/sharding.md
================================================
# Sharding

Sharding is a technique to partition the computation and the model across multiple accelerators.
This codebase supports flexible model and data parallelism for training and serving.

The sharding can be specified using the `mesh_dim` command line argument. The `mesh_dim` is a
comma separated list of integers representing the parallelism mesh axis dimensions. One of the
axis dimensions can be `-1`, which means that the axis dimension will be inferred based on the
total number of accelerators.

The first axis of the mesh is used for data parallelism (`dp`), the second for fully sharded
data parallelism (`fsdp`), the third for tensor parallelism (`tp`), and the last for
sequence parallelism (`sp`), which is required for RingAttention.

For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism, whereas `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.

Your total number of accelerators should be equal to the product of the mesh dimensions. For example, `mesh_dim='1,64,4,1'` requires 1 × 64 × 4 × 1 = 256 accelerators, and `mesh_dim='1,1,4,64'` also requires 256 accelerators.
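
The device-count arithmetic, including how a single `-1` entry can be inferred, can be sketched as follows. This is an illustrative helper only: `resolve_mesh_dim` is a hypothetical name and is not the function this codebase uses to build its mesh.

```python
import math


def resolve_mesh_dim(mesh_dim, n_devices):
    """Parse a 'dp,fsdp,tp,sp' string and fill in at most one -1 entry.

    Hypothetical helper for illustration; not part of the codebase.
    """
    dims = [int(d) for d in mesh_dim.split(",")]
    if dims.count(-1) > 1:
        raise ValueError("At most one mesh axis may be -1.")
    if -1 in dims:
        # The inferred axis absorbs whatever devices the other axes leave over.
        known = math.prod(d for d in dims if d != -1)
        if n_devices % known != 0:
            raise ValueError(f"{n_devices} devices are not divisible by {known}.")
        dims[dims.index(-1)] = n_devices // known
    if math.prod(dims) != n_devices:
        raise ValueError(
            f"Mesh {dims} needs {math.prod(dims)} devices, got {n_devices}."
        )
    return dict(zip(("dp", "fsdp", "tp", "sp"), dims))


print(resolve_mesh_dim("1,64,4,1", 256))
print(resolve_mesh_dim("1,-1,4,1", 256))
```

Both calls describe the same mesh: with 256 devices and `tp=4`, the `-1` entry resolves to `fsdp=64`.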

In general, you want to use the largest possible mesh dimension for `fsdp`. For example, `mesh_dim='1,64,1,1'` is preferred over `mesh_dim='8,8,1,1'` because the larger `fsdp` dimension allows computation and communication to overlap, which improves performance.

The batch size (number of sequences per batch) should be larger than or equal to `fsdp * dp`. If you think the batch size is too large, you can allocate more accelerators to `tp` and `sp` to increase the model size and sequence length.

RingAttention requires sequence parallelism, which is controlled by `sp`. `sp=8` means sharding the sequence length by 8, and `sp=1` means no sharding.
For models that use standard attention, you can set `sp=1` and use `dp`, `fsdp`, and `tp` to control the parallelism.
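
As a rough picture of what sharding the sequence length means, the NumPy sketch below cuts the sequence axis of a batch into `sp` contiguous chunks. It only illustrates the slicing, under the assumption that the sequence length is divisible by `sp`; `shard_sequence` is a hypothetical helper, and the exchange of key/value blocks between devices that RingAttention performs is not shown.

```python
import numpy as np


def shard_sequence(batch, sp, rank):
    """Return the contiguous slice of the sequence axis owned by `rank`."""
    seq_length = batch.shape[1]
    if seq_length % sp != 0:
        raise ValueError(f"seq_length {seq_length} is not divisible by sp={sp}.")
    chunk = seq_length // sp
    return batch[:, rank * chunk:(rank + 1) * chunk]


batch = np.arange(2 * 16).reshape(2, 16)  # (batch_size, seq_length)
shards = [shard_sequence(batch, sp=8, rank=r) for r in range(8)]
print(shards[0].shape)  # each shard holds seq_length / sp positions
```

With `sp=1` the single shard is the whole sequence, which matches the no-sharding case.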


================================================
FILE: gpu_requirements.txt
================================================
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12]==0.4.29
flax==0.8.4
optax==0.2.2
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.0
transformers==4.40.0
ringattention @ git+https://github.com/haoliuhl/ringattention.git
datasets
einops
tqdm
ml_collections
wandb
gcsfs
requests
typing-extensions
sentencepiece
tux @ git+https://github.com/haoliuhl/tux.git
Pillow
ffmpeg-python
ipdb
imageio[ffmpeg]
opencv-python
decord
h5py
psutil


================================================
FILE: lwm/__init__.py
================================================


================================================
FILE: lwm/data.py
================================================
import time
import random
from functools import partial
import json
from multiprocessing import Pool

from tux import open_file
from ml_collections import ConfigDict
import numpy as np
import jax
from jax.experimental.multihost_utils import host_local_array_to_global_array
from jax.sharding import PartitionSpec as PS
from datasets import load_dataset


class DatasetFactory(object):
    """ Dataset builder class. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.type = 'huggingface'
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        config.vision_text_processor = VisionTextProcessor.get_default_config()
        config.json_vision_dataset = JsonVisionDataset.get_default_config()

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        config = cls.get_default_config(config)
        if config.type == 'huggingface':
            text_processor = TextProcessor(config.text_processor, tokenizer)
            return HuggingfaceDataset(
                config.huggingface_dataset, tokenizer, text_processor, **kwargs
            )
        elif config.type == 'json':
            text_processor = TextProcessor(config.text_processor, tokenizer)
            return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
        elif config.type == 'json_vision':
            vision_text_processor = VisionTextProcessor(config.vision_text_processor, tokenizer)
            return JsonVisionDataset(config.json_vision_dataset, tokenizer, vision_text_processor, **kwargs)
        else:
            raise ValueError(f'Unknown dataset type: {config.type}')

    def __init__(self):
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')


class TextProcessor(object):
    """ Example processor that converts a dictionary of texts into tokens. """
    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.fields_from_example = ''
        config.fields = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        self.config = self.get_default_config(config)
        assert self.config.fields != '' or self.config.fields_from_example != '', (
            'Either fields or fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer

    def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
        if has_aux:
            example, *aux = example
        else:
            aux = tuple()
        token_buffer = []
        loss_mask_buffer = []

        if add_bos_token and self.config.add_bos_token:
            token_buffer.append(self.tokenizer.bos_token_id)
            loss_mask_buffer.append(0.0)

        if self.config.fields_from_example != '':
            fields = example[self.config.fields_from_example].split(',')
        else:
            fields = self.config.fields.split(',')

        for i, field in enumerate(fields):
            if field.startswith('[') and field.endswith(']'):
                # No loss for this field.
                field = field[1:-1]
                mask = 0.0
            else:
                mask = 1.0

            if field == '<|bos|>':
                token_buffer.append(self.tokenizer.bos_token_id)
                loss_mask_buffer.append(mask)
            elif field == '<|eos|>':
                token_buffer.append(self.tokenizer.eos_token_id)
                loss_mask_buffer.append(mask)
            else:
                subfields = field.split('+')
                text = self.config.subfield_separator.join(
                    [example[subfield] for subfield in subfields]
                )
                if i == 0:
                    text = self.config.prepend_text + text
                tokens = self.tokenizer.encode(text, add_special_tokens=False)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])

        if add_eos_token and self.config.add_eos_token:
            token_buffer.append(self.tokenizer.eos_token_id)
            loss_mask_buffer.append(1.0)

        return token_buffer, loss_mask_buffer, *aux


class VisionTextProcessor(object):
    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.fields_from_example = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        config.fields_index = -1
        config.eof_token = 8192 # denotes end of each frame for video generation
        config.eov_token = 8193 # denotes end of vision generation
        config.n_tokens_per_frame = 256 # 16 x 16 VQ codes
        config.max_n_frames = -1
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        self.config = self.get_default_config(config)
        assert self.config.fields_from_example != '', (
            'fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer
        self.vision_start = tokenizer.encode('<vision>')
        self.vision_end = tokenizer.encode('</vision>')

    def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
        if has_aux:
            example, *aux = example
        else:
            aux = tuple()
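        # Seed with the line index so augmentations are deterministic per example;
        # fall back to an unseeded generator when no aux values are passed.
        rand_state = random.Random(aux[-1] if aux else None)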
        token_buffer = []
        loss_mask_buffer = []
        vision_mask = []

        fields = example[self.config.fields_from_example]
        if isinstance(fields, (tuple, list)):
            if self.config.fields_index >= 0:
                fields = fields[self.config.fields_index]
            else:
                # seed based on line number
                fields = rand_state.choice(fields)
        fields = fields.split(',')

        if add_bos_token and self.config.add_bos_token:
            token_buffer.append(self.tokenizer.bos_token_id)
            loss_mask_buffer.append(0.0)
            vision_mask.append(False)

        for i, field in enumerate(fields):
            if field.startswith('[') and field.endswith(']'):
                # No loss for this field.
                field = field[1:-1]
                mask = 0.0
            else:
                mask = 1.0

            if field == '<|bos|>':
                token_buffer.append(self.tokenizer.bos_token_id)
                loss_mask_buffer.append(mask)
                vision_mask.append(False)
            elif field == '<|eos|>':
                token_buffer.append(self.tokenizer.eos_token_id)
                loss_mask_buffer.append(mask)
                vision_mask.append(False)
            elif 'vision' in field:
                vision_tokens = example[field]
                n_frames = int(len(vision_tokens) / self.config.n_tokens_per_frame)
                if self.config.max_n_frames > 0 and n_frames > self.config.max_n_frames: # uniformly select
                    idxs = np.linspace(0, n_frames - 1, self.config.max_n_frames).astype(int)
                    new_vision_tokens = []
                    for idx in idxs:
                        new_vision_tokens.extend(vision_tokens[idx * self.config.n_tokens_per_frame:(idx + 1) * self.config.n_tokens_per_frame])
                    vision_tokens = new_vision_tokens
                    n_frames = self.config.max_n_frames
                assert int(len(vision_tokens) / self.config.n_tokens_per_frame) == n_frames, (int(len(vision_tokens) / self.config.n_tokens_per_frame), n_frames)

                assert n_frames > 0, len(vision_tokens)
                tokens = list(self.vision_start)
                for j in range(n_frames):
                    tokens.extend(vision_tokens[j*self.config.n_tokens_per_frame:(j+1)*self.config.n_tokens_per_frame])
                    if j == n_frames - 1: # last frame
                        tokens.append(self.config.eov_token)
                    else:
                        tokens.append(self.config.eof_token)
                tokens.extend(self.vision_end)

                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])
                vision_mask.extend([False] * len(self.vision_start))
                vision_mask.extend([True] * (self.config.n_tokens_per_frame * n_frames + n_frames)) # include extra eof/eov token at the end of each frame
                vision_mask.extend([False] * len(self.vision_end))
            else:
                subfields = field.split('+')
                text = self.config.subfield_separator.join(
                    [example[subfield] for subfield in subfields]
                )
                if i == 0:
                    text = self.config.prepend_text + text
                tokens = self.tokenizer.encode(text)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])
                vision_mask.extend([False] * len(tokens))

        if add_eos_token and self.config.add_eos_token:
            token_buffer.append(self.tokenizer.eos_token_id)
            loss_mask_buffer.append(1.0)
            vision_mask.append(False)

        assert len(token_buffer) == len(loss_mask_buffer) == len(vision_mask), (len(token_buffer), len(loss_mask_buffer), len(vision_mask))
        keep = True
        return token_buffer, loss_mask_buffer, vision_mask, keep, *aux


class HuggingfaceDataset(object):
    """ Huggingface dataset, where the dataset is loaded using the huggingface
        datasets.load_dataset() function.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = 'c4'
        config.name = 'en'
        config.split = 'train'
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor):
        self.config = self.get_default_config(config)
        name = self.config.name if self.config.name != '' else None
        split = self.config.split if self.config.split != '' else None
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._dataset = load_dataset(
            self.config.path, name, split=split, streaming=self.config.streaming
        )

    def __iter__(self):
        chunk_size = self.config.batch_size * self.config.seq_length
        total_tokens = 0
        while True:
            token_buffer = []
            loss_mask_buffer = []
            for index, example in enumerate(self._dataset):
                tokens, loss_masks = self.text_processor(example)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend(loss_masks)
                while len(token_buffer) > chunk_size + 1:
                    total_tokens += chunk_size
                    metrics = {
                        'dataset_example_index': index,
                        'dataset_total_tokens': total_tokens,
                    }
                    batch = {
                        'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
                            self.config.batch_size, -1
                        ),
                        'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
                            self.config.batch_size, -1
                        ),
                        'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                            self.config.batch_size, -1
                        ),
                    }
                    if self.config.always_start_with_bos:
                        batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                    yield batch, metrics
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def get_state_dict(self):
        return dict(config=self.config)

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def dataset(self):
        return self._dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)


class JsonDataset(object):
    """ JSON dataset, where each line of the data file contains a JSON
        dictionary with text fields.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = ''
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200
        config.pad = False
        config.use_data_sharded_loader = True
        config.return_local_batch = False

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, node_info):
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._node_info = node_info
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        if not line or line == '\n':
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        index, file_loc = self._index, self._file_loc
        with open_file(self.config.path, 'r') as fin:
            fin.seek(file_loc)
            while True:
                line = fin.readline()
                file_loc = fin.tell()
                if not line:   # Reached EOF
                    index = 0
                    fin.seek(0)
                    continue

                data = self.parse_json(line)
                if data is not None and (not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']):
                    # JSON parsing succeeded
                    yield data, file_loc, index
                index += 1

    def batched(self, iterator, batch_size):
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                self._file_loc = loc
                self._index = index
                yield self.text_processor((example, loc, index), has_aux=True)
        else:
            process_pool = Pool(self.config.tokenizer_processes)
            batched_iterator = self.batched(
                self.json_iterator(), self.config.tokenizer_parallel_batch_size
            )
            with process_pool as pool:
                map_fn = partial(self.text_processor, has_aux=True)
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator),
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                while True:
                    current_batch = next_batch
                    next_batch = pool.map_async(
                        map_fn, next(batched_iterator),
                        chunksize=self.config.tokenizer_parallel_chunk_size
                    )
                    for example in current_batch.get():
                        yield example

    def __iter__(self):
        global_chunk_size = self.config.batch_size * self.config.seq_length
        if self.config.use_data_sharded_loader:
            local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
        else:
            local_batch_size = self.config.batch_size
        chunk_size = local_batch_size * self.config.seq_length

        token_buffer = []
        loss_mask_buffer = []

        # Start the step timer now so the first step time is not measured from the epoch.
        last_time = time.time()
        step_times = []
        start_time = time.time()
        start_tokens = self._total_tokens
        for tokens, loss_masks, loc, index in self.parallel_example_iterator():
            self._file_loc = loc
            self._index = index
            if self.config.pad:
                tokens = tokens[:self.config.seq_length + 1]
                tokens.extend([self._tokenizer.bos_token_id] * (self.config.seq_length + 1 - len(tokens)))
                loss_masks = loss_masks[:self.config.seq_length + 1]
                loss_masks.extend([0.0] * (self.config.seq_length + 1 - len(loss_masks)))
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += global_chunk_size
                step_times.append(time.time() - last_time)
                last_time = time.time()
                if len(step_times) > self.config.throughput_average_window_size:
                    step_times = step_times[-self.config.throughput_average_window_size:]
                average_throughput = global_chunk_size / np.mean(step_times)
                accumulated_throughput = (
                    (self._total_tokens - start_tokens) / (time.time() - start_time)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = {
                    'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
                        local_batch_size, -1
                    ),
                    'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
                        local_batch_size, -1
                    ),
                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                        local_batch_size, -1
                    ),
                }
                batch.update({
                    'input_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),
                    'target_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),
                })
                if self.config.always_start_with_bos:
                    batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id

                if self.config.use_data_sharded_loader and not self.config.return_local_batch:
                    mesh = self._node_info['mesh']
                    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
                    sp_nodes_rank = jax.process_index() % sp_nodes_size
                    assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size)
                    seq_chunk_size = self.config.seq_length // sp_nodes_size
                    batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
                    batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))

                yield batch, metrics
                if self.config.pad:
                    token_buffer, loss_mask_buffer = [], []
                else:
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def _make_callback(self, v):
        return lambda index: v[index]

    def get_state_dict(self):
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self.tokenizer)


class JsonVisionDataset(object):
    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = ''
        config.seq_length = 384
        config.batch_size = 4
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200
        config.use_data_sharded_loader = True
        config.return_local_batch = False
        config.mode = 'pad'

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, node_info):
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._node_info = node_info
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        if not line or line == '\n':
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        index, file_loc = self._index, self._file_loc
        with open_file(self.config.path, 'r', block_size=50 * 2 ** 20) as fin:
            fin.seek(file_loc)
            while True:
                line = fin.readline()
                file_loc = fin.tell()
                if not line:   # Reached EOF
                    index = 0
                    fin.seek(0)
                    continue
                if not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']:
                    data = self.parse_json(line)
                    if data is not None:
                        # JSON parsing succeeded
                        yield data, file_loc, index
                index += 1

    def batched(self, iterator, batch_size):
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                self._file_loc = loc
                self._index = index
                yield self.text_processor((example, loc, index), has_aux=True)
        else:
            process_pool = Pool(self.config.tokenizer_processes)
            batched_iterator = self.batched(
                self.json_iterator(), self.config.tokenizer_parallel_batch_size
            )
            with process_pool as pool:
                map_fn = partial(self.text_processor, has_aux=True)
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator),
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                while True:
                    current_batch = next_batch
                    next_batch = pool.map_async(
                        map_fn, next(batched_iterator),
                        chunksize=self.config.tokenizer_parallel_chunk_size
                    )
                    for example in current_batch.get():
                        yield example

    def __iter__(self):
        if self.config.mode == 'pad':
            fn = self._iter_pad
        elif self.config.mode == 'no_pad':
            fn = self._iter_no_pad
        else:
            raise ValueError(f'Unknown mode: {self.config.mode}')
        return fn()

    def _iter_pad(self):
        chunk_size = self.config.batch_size * self.config.seq_length
        if self.config.use_data_sharded_loader:
            local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
        else:
            local_batch_size = self.config.batch_size
        # Start the clock now so the first step time is not measured from the epoch.
        last_time = time.time()
        buffer = []
        step_times = []
        start_time = time.time()
        start_tokens = self._total_tokens
        for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():
            if not keep:
                continue
            self._file_loc = loc
            self._index = index
            buffer.append((tokens, loss_masks, vision_masks))
            while len(buffer) >= local_batch_size:
                self._total_tokens += chunk_size
                step_times.append(time.time() - last_time)
                last_time = time.time()
                if len(step_times) > self.config.throughput_average_window_size:
                    step_times = step_times[-self.config.throughput_average_window_size:]
                average_throughput = chunk_size / np.mean(step_times)
                accumulated_throughput = (
                    (self._total_tokens - start_tokens) / (time.time() - start_time)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }

                batch = {
                    'input_tokens': np.full(
                        (local_batch_size, self.config.seq_length),
                        self._tokenizer.bos_token_id,
                        dtype=np.int32
                    ),
                    'target_tokens': np.full(
                        (local_batch_size, self.config.seq_length),
                        self._tokenizer.bos_token_id,
                        dtype=np.int32
                    ),
                    'loss_masks': np.zeros(
                        (local_batch_size, self.config.seq_length),
                        dtype=np.float32
                    ),
                    'input_vision_masks': np.zeros(
                        (local_batch_size, self.config.seq_length),
                        dtype=bool
                    ),
                    'target_vision_masks': np.zeros(
                        (local_batch_size, self.config.seq_length),
                        dtype=bool
                    )
                }
                for i in range(local_batch_size):
                    tokens, loss_masks, vision_masks = buffer[i]
                    if len(tokens) > self.config.seq_length:
                        tokens = tokens[:self.config.seq_length + 1]
                        # Truncate only; the shift by one is applied once, below.
                        loss_masks = loss_masks[:self.config.seq_length + 1]
                        vision_masks = vision_masks[:self.config.seq_length + 1]
                    input_tokens, target_tokens = tokens[:-1], tokens[1:]
                    input_vision_masks, target_vision_masks = vision_masks[:-1], vision_masks[1:]
                    loss_masks = loss_masks[1:]
                    batch['input_tokens'][i, :len(input_tokens)] = input_tokens
                    batch['target_tokens'][i, :len(target_tokens)] = target_tokens
                    batch['input_vision_masks'][i, :len(input_vision_masks)] = input_vision_masks
                    batch['target_vision_masks'][i, :len(target_vision_masks)] = target_vision_masks
                    batch['loss_masks'][i, :len(loss_masks)] = loss_masks

                if self.config.use_data_sharded_loader and not self.config.return_local_batch:
                    mesh = self._node_info['mesh']
                    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
                    sp_nodes_rank = jax.process_index() % sp_nodes_size
                    assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size)
                    seq_chunk_size = self.config.seq_length // sp_nodes_size
                    batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
                    batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
                yield batch, metrics
                buffer = buffer[local_batch_size:]

    def _iter_no_pad(self):
        global_chunk_size = self.config.batch_size * self.config.seq_length
        if self.config.use_data_sharded_loader:
            local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
        else:
            local_batch_size = self.config.batch_size
        chunk_size = local_batch_size * self.config.seq_length

        token_buffer = []
        loss_mask_buffer = []
        vision_mask_buffer = []

        # Start the clock now so the first step time is not measured from the epoch.
        last_time = time.time()
        step_times = []
        start_time = time.time()
        start_tokens = self._total_tokens
        for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():
            if not keep:
                continue
            self._file_loc = loc
            self._index = index
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            vision_mask_buffer.extend(vision_masks)
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += global_chunk_size
                step_times.append(time.time() - last_time)
                last_time = time.time()
                if len(step_times) > self.config.throughput_average_window_size:
                    step_times = step_times[-self.config.throughput_average_window_size:]
                average_throughput = global_chunk_size / np.mean(step_times)
                accumulated_throughput = (
                    (self._total_tokens - start_tokens) / (time.time() - start_time)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = {
                    'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
                        local_batch_size, -1
                    ),
                    'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
                        local_batch_size, -1
                    ),
                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
                        local_batch_size, -1
                    ),
                    'input_vision_masks': np.array(vision_mask_buffer[:chunk_size], dtype=bool).reshape(
                        local_batch_size, -1
                    ),
                    'target_vision_masks': np.array(vision_mask_buffer[1:chunk_size + 1], dtype=bool).reshape(
                        local_batch_size, -1
                    ),
                }

                if self.config.use_data_sharded_loader and not self.config.return_local_batch:
                    mesh = self._node_info['mesh']
                    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
                    sp_nodes_rank = jax.process_index() % sp_nodes_size
                    assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size)
                    seq_chunk_size = self.config.seq_length // sp_nodes_size
                    batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
                    batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))

                yield batch, metrics
                token_buffer = token_buffer[chunk_size:]
                loss_mask_buffer = loss_mask_buffer[chunk_size:]
                vision_mask_buffer = vision_mask_buffer[chunk_size:]


    def _make_callback(self, v):
        return lambda index: v[index]

    def get_state_dict(self):
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self._tokenizer)
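

# Illustrative sketch (not part of the original module): how the 'no_pad' mode
# above slices one flat token stream into next-token-prediction batches.
# Assumptions: tokens are plain Python ints, masks and sharding are omitted,
# and `_pack_stream_sketch` is a hypothetical helper name used only here.
import numpy as np


def _pack_stream_sketch(token_stream, batch_size, seq_length):
    """Yield (input_tokens, target_tokens) arrays of shape (batch_size, seq_length)."""
    chunk_size = batch_size * seq_length
    buffer = list(token_stream)
    # Same condition as the loader: one extra token is needed so that the
    # last input position still has a target.
    while len(buffer) > chunk_size + 1:
        inputs = np.array(buffer[:chunk_size], dtype=np.int32)
        targets = np.array(buffer[1:chunk_size + 1], dtype=np.int32)
        yield inputs.reshape(batch_size, -1), targets.reshape(batch_size, -1)
        # Advance by chunk_size, so the next chunk starts at the last target.
        buffer = buffer[chunk_size:]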


================================================
FILE: lwm/llama.py
================================================
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import tempfile
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as PS
from jax.experimental.shard_map import shard_map
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen import partitioning as nn_partitioning

import sentencepiece as spm
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ml_collections import ConfigDict
from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh
from ringattention import ringattention, blockwise_feedforward, ringattention_jax, ringattention_inference


LLAMA_STANDARD_CONFIGS = {
    '200m': {
        'vocab_size': 32000,
        'hidden_size': 1024,
        'intermediate_size': 2048,
        'num_hidden_layers': 14,
        'num_attention_heads': 8,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    '1b': {
        'vocab_size': 32000,
        'hidden_size': 2048,
        'intermediate_size': 5504,
        'num_hidden_layers': 22,
        'num_attention_heads': 16,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    '3b': {
        'vocab_size': 32000,
        'hidden_size': 3200,
        'intermediate_size': 8640,
        'num_hidden_layers': 26,
        'num_attention_heads': 32,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    '7b': {
        'vocab_size': 32000,
        'hidden_size': 4096,
        'intermediate_size': 11008,
        'num_hidden_layers': 32,
        'num_attention_heads': 32,
        'max_sequence_length': 4096,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    '13b': {
        'vocab_size': 32000,
        'hidden_size': 5120,
        'intermediate_size': 13824,
        'num_hidden_layers': 40,
        'num_attention_heads': 40,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    '30b': {
        'vocab_size': 32000,
        'hidden_size': 6656,
        'intermediate_size': 17920,
        'num_hidden_layers': 60,
        'num_attention_heads': 52,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    '65b': {
        'vocab_size': 32000,
        'hidden_size': 8192,
        'intermediate_size': 22016,
        'num_hidden_layers': 80,
        'num_attention_heads': 64,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-5,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
    'debug': { # A small model for debugging
        'vocab_size': 32000,
        'hidden_size': 256,
        'intermediate_size': 256,
        'num_hidden_layers': 2,
        'num_attention_heads': 2,
        'max_sequence_length': 2048,
        'initializer_range': 0.02,
        'rms_norm_eps': 1e-6,
        'use_cache': True,
        'tie_word_embeddings': False,
    },
}


class LLaMAConfig(PretrainedConfig):
    model_type = "llama"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        max_sequence_length=4096,
        rms_norm_eps=1e-6,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=1,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        tie_word_embeddings=False,
        scan_attention=True,
        scan_mlp=True,
        scan_query_chunk_size=1024,
        scan_key_chunk_size=1024,
        scan_mlp_chunk_size=1024,
        scan_layers=True,
        param_scan_axis=0,
        mesh_dim=None,
        theta=10000,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_sequence_length = max_sequence_length
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.scan_attention = scan_attention
        self.scan_mlp = scan_mlp
        self.scan_query_chunk_size = scan_query_chunk_size
        self.scan_key_chunk_size = scan_key_chunk_size
        self.scan_mlp_chunk_size = scan_mlp_chunk_size
        self.scan_layers = scan_layers
        self.param_scan_axis = param_scan_axis
        self.mesh_dim = mesh_dim
        self.theta = theta
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @classmethod
    def get_default_config(cls, updates=None):
        config = function_args_to_config(cls.__init__)

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())

        return config

    @staticmethod
    def get_jax_mesh(axis_dims):
        return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp'))

    @staticmethod
    def get_ranks_and_size(mesh):
        out = dict(mesh=mesh)
        mp_size = mesh.shape['tp'] * mesh.shape['sp']
        mp_node_size = max(1, mp_size // jax.local_device_count())
        dp_node_size = jax.process_count() // mp_node_size
        out.update(mp_node_size=mp_node_size,
                   dp_node_size=dp_node_size)

        dp_node_rank = jax.process_index() // mp_node_size
        mp_node_rank = jax.process_index() % mp_node_size
        out.update(dp_node_rank=dp_node_rank,
                   mp_node_rank=mp_node_rank)
        return out


    @staticmethod
    def get_partition_rules(scan_layers=False, scan_axis=0):
        """Partition rules are ordered, so that earlier rules match first."""
        if scan_layers:
            if scan_axis == 0:
                return (
                    # embeddings
                    ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
                    # attention
                    ("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")),
                    ("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))),
                    # mlp
                    ("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")),
                    ("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))),
                    ("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")),
                    # layer norms
                    ("attention_norm/kernel", PS(None, None)),
                    ("ffn_norm/kernel", PS(None, None)),
                    # output head
                    ("transformer/ln_f/kernel", PS(None)),
                    ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
                    ('.*', PS(None)),
                )
            elif scan_axis == 1:
                return (
                    # embeddings
                    ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
                    # attention
                    ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")),
                    ("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))),
                    # mlp
                    ("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")),
                    ("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))),
                    ("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")),
                    # layer norms
                    ("attention_norm/kernel", PS(None, None)),
                    ("ffn_norm/kernel", PS(None, None)),
                    # output head
                    ("transformer/ln_f/kernel", PS(None)),
                    ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
                    ('.*', PS(None)),
                )
            else:
                raise ValueError(f"Invalid scan_axis {scan_axis}")
        else:
            return (
                # embeddings
                ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
                # attention
                ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")),
                ("attention/wo/kernel", PS("tp", ("fsdp", "sp"))),
                # mlp
                ("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")),
                ("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))),
                ("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")),
                # layer norms
                ("attention_norm/kernel", PS(None)),
                ("ffn_norm/kernel", PS(None)),
                # output head
                ("transformer/ln_f/kernel", PS(None)),
                ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
                ('.*', PS(None)),
            )

    @staticmethod
    def get_weight_decay_exclusions():
        return tuple()

    @staticmethod
    def get_frozen_param_exclusions(freeze_base):
        if freeze_base:
            return ("vte", "vision_head")
        else:
            return tuple()

    @staticmethod
    def rng_keys():
        return ('params', 'dropout')

    @classmethod
    def load_config(cls, path):
        if path in LLAMA_STANDARD_CONFIGS:
            return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])
        load_type, load_path = path.split('::', 1)
        if load_type == 'pickle':
            return cls.from_dict(load_pickle(load_path)['llama_config'])
        elif load_type == 'json':
            with open_file(load_path, 'r') as fin:
                raw_config = fin.read()
            return cls.from_dict(json.loads(raw_config))
        else:
            raise ValueError(f'Unsupported load config type: {load_type}')


remat = nn_partitioning.remat

logger = logging.get_logger(__name__)


class RMSNorm(nn.Module):
    dim: int
    eps: float=1e-6
    dtype: jnp.dtype=jnp.float32
    param_dtype: jnp.dtype=jnp.float32

    def setup(self) -> None:
        self.weight = self.param(
            'kernel',
            nn.initializers.ones,
            (self.dim,),
            self.param_dtype,
        )

    def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
        return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
        output = self._norm(x).astype(self.dtype)
        weight = jnp.asarray(self.weight, self.dtype)
        return output * weight
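

# Illustrative sketch (not part of the original module): the same RMS
# normalization as RMSNorm above, written in plain NumPy with no dtype
# handling. `_rms_norm_sketch` is a hypothetical helper name used only here.
import numpy as np


def _rms_norm_sketch(x, weight, eps=1e-6):
    """Scale each last-axis vector of x to roughly unit RMS, then apply weight."""
    inv_rms = 1.0 / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x * inv_rms * weight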


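def precompute_freqs_cis(dim: int, max_position_embedding: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:
    """Precompute the rotary embedding factors exp(i * t * theta ** (-2k / dim)).

    Returns a complex64 array of shape (max_position_embedding, dim // 2),
    indexed by position t and frequency index k.
    """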
    freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
    t = np.arange(max_position_embedding) # type: ignore
    freqs = np.outer(t, freqs).astype(dtype)  # type: ignore
    sin, cos = np.sin(freqs), np.cos(freqs)
    freqs_cis = np.complex64(cos + 1j * sin)
    return jnp.asarray(freqs_cis)


def apply_rotary_emb(
    xq: jnp.ndarray,
    xk: jnp.ndarray,
    freqs_cis: jnp.ndarray,
    dtype: jnp.dtype=jnp.float32,
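) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply rotary position embeddings to the query and key tensors.

    Adjacent feature pairs are viewed as complex numbers and multiplied by
    freqs_cis; the results are cast back to dtype.
    """
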
    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

    # add head dim
    freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))

    xq_out = xq_ * freqs_cis
    xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)

    xk_out = xk_ * freqs_cis
    xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)

    return xq_out.astype(dtype), xk_out.astype(dtype)
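

# Illustrative sketch (not part of the original module): a NumPy-only view of
# the pairing used by apply_rotary_emb above. Adjacent features
# (x[2k], x[2k + 1]) form one complex number, and multiplying by
# exp(i * angle) is a plain 2-D rotation of that pair.
# `_rope_pair_rotation_sketch` is a hypothetical helper name used only here.
import numpy as np


def _rope_pair_rotation_sketch(x, angles):
    """Rotate adjacent feature pairs of x (..., dim) by angles (..., dim // 2)."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    as_complex = pairs[..., 0] + 1j * pairs[..., 1]
    rotated = as_complex * np.exp(1j * angles)
    # Interleave real and imaginary parts back into the original layout.
    return np.stack((rotated.real, rotated.imag), axis=-1).reshape(x.shape)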


class FlaxLLaMAAttention(nn.Module):
    config: LLaMAConfig
    dtype: jnp.dtype=jnp.float32
    param_dtype: jnp.dtype=jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]]=None

    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.wq = nn.Dense(
            config.num_attention_heads*self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )
        self.wk = nn.Dense(
            config.num_attention_heads*self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )
        self.wv = nn.Dense(
            config.num_attention_heads*self.head_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )
        self.wo = nn.Dense(
            config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )

        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")

        self.freqs_cis = precompute_freqs_cis(
            self.head_dim,
            config.max_sequence_length,
            theta=config.theta,
            dtype=self.dtype,
        )

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            if query.shape[1] == 1:
                mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
                def fn(cached_key, cached_value, key, value, cur_index):
                    assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
                    sp_size = max_length // mesh.shape['sp']
                    axis_index = jax.lax.axis_index('sp')
                    cur_index = cur_index - axis_index * sp_size
                    key, value = jax.lax.cond(
                        jnp.logical_and(cur_index >= 0, cur_index < sp_size),
                        lambda: (
                            cached_key.at[:, cur_index].set(key[:, -1]),
                            cached_value.at[:, cur_index].set(value[:, -1]),
                        ),
                        lambda: (cached_key, cached_value),
                    )
                    return key, value
                fn = shard_map(
                    fn, mesh=mesh,
                    in_specs=(
                        PS(('dp', 'fsdp'), 'sp', 'tp', None),
                        PS(('dp', 'fsdp'), 'sp', 'tp', None),
                        PS(('dp', 'fsdp'), None, 'tp', None),
                        PS(('dp', 'fsdp'), None, 'tp', None),
                        PS()
                    ),
                    out_specs=(
                        PS(('dp', 'fsdp'), 'sp', 'tp', None),
                        PS(('dp', 'fsdp'), 'sp', 'tp', None)
                    ),
                    check_rep=False
                )
                key, value = fn(cached_key.value, cached_value.value, key, value, cur_index)
            else:
                indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        segment_ids,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        if xq.shape[1] == 1:
            xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "tp"))
        else:
            xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), "sp", "tp"))
        xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), "sp", "tp"))
        xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), "sp", "tp"))

        xq = self._split_heads(xq)
        xk = self._split_heads(xk)
        xv = self._split_heads(xv)

        freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.config.scan_attention and xq.shape[1] > max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size):
            # attention mask without n x n materialization; blockwise_attn will handle the rest
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

            if self.has_variable("cache", "cached_key") or init_cache:
                xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)

            # transform boolean mask into float mask
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
            attn_weights = None
            ring_attention_sharded = shard_map(
                partial(
                    ringattention,
                    axis_name="sp",
                    float32_logits=True,
                    cache_idx=None,
                    blockwise_kwargs=dict(
                        causal_block_size=1,
                        deterministic=deterministic,
                        dropout_rng=dropout_rng,
                        attn_pdrop=self.config.attn_pdrop,
                        query_chunk_size=self.config.scan_query_chunk_size,
                        key_chunk_size=self.config.scan_key_chunk_size,
                        dtype=self.dtype,
                        policy=jax.checkpoint_policies.nothing_saveable,
                        precision=self.precision,
                        prevent_cse=not self.config.scan_layers,
                    )
                ),
                mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
                in_specs=(
                    PS(("dp", "fsdp"), "sp", "tp", None),
                    PS(("dp", "fsdp"), "sp", "tp", None),
                    PS(("dp", "fsdp"), "sp", "tp", None),
                    PS(("dp", "fsdp"), None, None, None),
                    PS(("dp", "fsdp"), None),
                ),
                out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
                check_rep=False
            )
            attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)
            attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), "sp", "tp", None))
        else:
            query_length, key_length = xq.shape[1], xk.shape[1]

            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = jnp.arange(max_decoder_length)[None] <= (jnp.arange(query_length) + mask_shift)[:, None]
                causal_mask = causal_mask[None, None]
                segment_mask = None
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
                if segment_ids is not None:
                    segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]
                    segment_mask = segment_mask[:, None]
                else:
                    segment_mask = None

            batch_size = hidden_states.shape[0]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask, segment_mask)

            # During fast autoregressive decoding, we feed one position at a time,
            # and cache the keys and values step by step.
            if self.has_variable("cache", "cached_key") or init_cache:
                xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)

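            # A single-token query (one decode step) cannot be split along the
            # sequence axis, so it is left unsharded on 'sp'.
            q_sp_dim = None if xq.shape[1] == 1 else 'sp'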
            attn_weights = None
            ring_attention_sharded = shard_map(
                partial(ringattention_inference, axis_name="sp"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
                in_specs=(
                    PS(("dp", "fsdp"), q_sp_dim, "tp", None),
                    PS(("dp", "fsdp"), "sp", "tp", None),
                    PS(("dp", "fsdp"), "sp", "tp", None),
                    PS(("dp", "fsdp"), None, q_sp_dim, None)
                ),
                out_specs=PS(("dp", "fsdp"), q_sp_dim, "tp", None),
                check_rep=False
            )
            attn_output = ring_attention_sharded(
                xq, xk, xv, attention_mask
            )

        attn_output = self._merge_heads(attn_output)
        attn_output = self.wo(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs


class FlaxLLaMAMLP(nn.Module):
    """Gated feed-forward layer: w2(silu(w1(x)) * w3(x)), followed by dropout."""
    config: LLaMAConfig
    dtype: jnp.dtype=jnp.float32
    param_dtype: jnp.dtype=jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]]=None

    def setup(self) -> None:
        config = self.config

        self.w1 = nn.Dense(
            config.intermediate_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )
        self.w2 = nn.Dense(
            config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )
        self.w3 = nn.Dense(
            config.intermediate_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            precision=self.precision,
        )
        self.dropout = nn.Dropout(rate=self.config.resid_pdrop)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
        x = self.dropout(x, deterministic=deterministic)
        return x


class FlaxLLaMABlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""
    config: LLaMAConfig
    dtype: jnp.dtype=jnp.float32
    param_dtype: jnp.dtype=jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]]=None

    def setup(self) -> None:
        attention_module = FlaxLLaMAAttention
        mlp_module = FlaxLLaMAMLP
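        # Rematerialize the MLP: with `nothing_saveable`, no intermediates are kept
        # for the backward pass, so they are recomputed to save memory.
        if self.config.scan_mlp: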
            mlp_module = remat(
                mlp_module, static_argnums=(1,),
                policy=jax.checkpoint_policies.nothing_saveable,
                prevent_cse=not self.config.scan_layers,
            )
        self.attention = attention_module(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )
        self.feed_forward = mlp_module(
            self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
        )
        self.attention_norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        self.ffn_norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        segment_ids=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        attn_outputs = self.attention(
            self.attention_norm(hidden_states),
            attention_mask,
            segment_ids,
            position_ids,
            deterministic,
            init_cache,
            output_attentions,
        )
        attn_output = attn_outputs[0]
        hidden_states = hidden_states + attn_output

        feed_forward_input = self.ffn_norm(hidden_states)

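        # Use the blockwise feed-forward path only when the sequence is at least one
        # chunk long; shorter inputs (e.g. single-token decoding) call the MLP directly.
        if self.config.scan_mlp and hidden_states.shape[1] >= self.config.scan_mlp_chunk_size: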
            feed_forward_hidden_states = blockwise_feedforward(
                self.feed_forward,
                feed_forward_input,
                self.config.scan_mlp_chunk_size,
                pre_remat=True,
            )
        else:
            feed_forward_hidden_states = self.feed_forward(feed_forward_input, deterministic)
        feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "tp"))

        hidden_states = hidden_states + feed_forward_hidden_states

        outputs = hidden_states
        if self.config.scan_layers:
            # nn.scan expects each step to return a (carry, per-step output) pair.
            outputs = (outputs, None)
        return outputs


class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LLaMAConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: LLaMAConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        segment_ids = None
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                segment_ids,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, segment_ids, position_ids, return_dict=False)

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        segment_ids = None
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    @add_start_docstrings_to_model_forward("")
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        segment_ids=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            segment_ids,
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs


class FlaxLLaMABlockCollection(nn.Module):
    config: LLaMAConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype=jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]]=None

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        segment_ids=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

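        block = FlaxLLaMABlock
        if self.config.scan_layers:
            # Run a single block definition `num_hidden_layers` times with nn.scan,
            # stacking the per-layer parameters along `param_scan_axis`. This path
            # does not collect per-layer hidden states or attentions.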
            initializing = self.is_mutable_collection('params')
            params_spec = (
                self.config.param_scan_axis if initializing else
                nn_partitioning.ScanIn(self.config.param_scan_axis))
            cache_spec = 0
            hidden_states, _ = nn.scan(
                block,
                variable_axes={
                    'params': params_spec,
                    'cache': cache_spec,
                    'intermediates': 0
                },
                split_rngs={
                    'params': True,
                    'dropout': True
                },
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
                length=self.config.num_hidden_layers,
                metadata_params={nn.PARTITION_NAME: 'scan_decoder_layer'},
                )(self.config, name='scan_decoder', dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)(
                    hidden_states,
                    attention_mask,
                    segment_ids,
                    position_ids,
                    deterministic,
                    init_cache,
                    output_attentions,
                )
        else:
            blocks = [
                block(
                    self.config,
                    name=str(i),
                    dtype=self.dtype,
                    param_dtype=self.param_dtype,
                    precision=self.precision,
                ) for i in range(self.config.num_hidden_layers)
            ]
            for block in blocks:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = block(
                    hidden_states,
                    attention_mask,
                    segment_ids,
                    position_ids,
                    deterministic,
                    init_cache,
                    output_attentions,
                )
                hidden_states = layer_outputs

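                if output_attentions:
                    # FlaxLLaMABlock returns only the hidden states here, so indexing
                    # `layer_outputs[1]` would slice the activations. Record a None
                    # placeholder instead of a bogus attention tensor.
                    all_attentions += (None,)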

        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs


class FlaxLLaMAModule(nn.Module):
    config: LLaMAConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype=jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]]=None

    def setup(self):
        self.embed_dim = self.config.hidden_size

        self.wte = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
        self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
        self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        segment_ids,
        position_ids,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        input_embeds = self.wte(input_ids.astype("i4"))
        assert input_embeds.shape[1] <= self.config.max_sequence_length, f"Input sequence length {input_embeds.shape[1]} larger than max supported sequence length {self.config.max_sequence_length}"

        hidden_states = self.dropout(input_embeds, deterministic=deterministic)

        outputs = self.h(
            hidden_states,
            attention_mask,
            segment_ids=segment_ids,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )

class FlaxLLaMAForCausalLMModule(nn.Module):
    config: LLaMAConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype=jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]]=None

    def setup(self):
        self.transformer = FlaxLLaMAModule(
            self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision
        )
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            precision=self.precision,
        )

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        segment_ids=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            position_ids = jnp.arange(seq_length, dtype=jnp.int32)[None].repeat(batch_size, axis=0)
        outputs = self.transformer(
            input_ids,
            attention_mask,
            segment_ids,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

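        # With tied embeddings, reuse the transposed token-embedding matrix as the
        # output projection instead of the separate lm_head kernel.
        if self.config.tie_word_embeddings: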
            shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


@add_start_docstrings("", "")
class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
    module_class = FlaxLLaMAForCausalLMModule

    def prepare_inputs_for_generation(
        self, input_ids, max_length,
        attention_mask: Optional[jax.Array] = None,
    ):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
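        # The mask spans the whole cache length. Slots past the prompt start as 1;
        # the causal mask keeps queries from attending to positions not yet written.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Count only attended tokens so that padding does not advance the position index.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))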
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


================================================
FILE: lwm/train.py
================================================
import pprint
import os
from functools import partial

from tqdm import tqdm, trange
import numpy as np
from absl.app import run
import absl.logging as logging
import tux

import jax
import flax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from transformers import AutoTokenizer

from lwm.data import DatasetFactory
from tux import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
    cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
    set_random_seed, average_metrics, get_mask,
    make_shard_and_gather_fns, with_sharding_constraint, define_flags_with_default,
    OptimizerFactory, StreamingCheckpointer
)
from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLMModule
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLMModule


FLAGS, FLAGS_DEF = define_flags_with_default(
    modality='text',
    use_data_sharded_loader=True,
    seed=42,
    mesh_dim='1,-1,1,1',
    dtype='fp32',
    total_steps=10000,
    load_llama_config='',
    update_llama_config='',
    load_checkpoint='',
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_steps=0,
    tokenizer='LargeWorldModel/LWM-Text-1M',
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    llama=VideoLLaMAConfig.get_default_config(),
    logger=tux.WandBLogger.get_default_config(),
    log_all_worker=False,
    jax_distributed=JaxDistributedConfig.get_default_config(),
    autoresume=False,
)


def main(argv):
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = tux.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF)

    logger = tux.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    if jax.process_index() == 0:
        output_dir = logger.output_dir
    else:
        output_dir = os.path.join(logger.output_dir, logger.experiment_id)

    if FLAGS.modality == 'text':
        config_cls = LLaMAConfig
        llama_cls = FlaxLLaMAForCausalLMModule
    elif FLAGS.modality == 'vision,text':
        config_cls = VideoLLaMAConfig
        llama_cls = FlaxVideoLLaMAForCausalLMModule
    else:
        raise ValueError(f"Unsupported modality: {FLAGS.modality}")

    mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim)
    node_info = config_cls.get_ranks_and_size(mesh)

    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info)
    if FLAGS.autoresume and tux.check_exists(output_dir):
        logging.info('Found existing output. Resuming dataset from latest checkpoint...')
        resume_path = f"{output_dir}/dataset.pkl"
        dataset.load_state_dict(tux.load_pickle(resume_path))
    elif FLAGS.load_dataset_state != '':
        dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state))

    if FLAGS.eval_steps > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer
        )
        eval_iterator = iter(eval_dataset)

    seq_length = dataset.seq_length

    if FLAGS.load_llama_config != '':
        llama_config = config_cls.load_config(FLAGS.load_llama_config)
        updates = config_cls(**FLAGS.llama)
        llama_config.update(dict(
            scan_attention=updates.scan_attention,
            scan_mlp=updates.scan_mlp,
            scan_query_chunk_size=updates.scan_query_chunk_size,
            scan_key_chunk_size=updates.scan_key_chunk_size,
            scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
            scan_layers=updates.scan_layers,
            param_scan_axis=updates.param_scan_axis,
        ))
    else:
        llama_config = config_cls(**FLAGS.llama)

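    if FLAGS.update_llama_config != '':
        # NOTE: the flag value is evaluated as a Python expression and must produce
        # something dict() accepts. eval() runs arbitrary code, so pass trusted input only.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))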

    llama_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
    ))
    if llama_config.vocab_size < dataset.vocab_size:
        llama_config.update(dict(vocab_size=dataset.vocab_size))
    llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))

    model = llama_cls(
        llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_mask(config_cls.get_weight_decay_exclusions()),
        None,
    )

    def create_trainstate_from_params(params):
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        rng_generator = JaxRNG(rng)
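        # Dummy batch size for the placeholder inputs that drive initialization;
        # the parameter shapes do not depend on it.
        batch = 512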
        if FLAGS.modality == 'text':
            params = model.init(
                input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
                position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
                attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),
                rngs=rng_generator(llama_config.rng_keys()),
            )
        elif FLAGS.modality == 'vision,text':
            params = model.init(
                input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
                vision_masks=jnp.zeros((batch, seq_length), dtype=bool),
                position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
                attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),
                rngs=rng_generator(llama_config.rng_keys()),
            )
        else:
            raise ValueError(f"Unsupported modality: {FLAGS.modality}")
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def train_step(train_state, rng, batch):
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
        def loss_and_accuracy(params):
            if FLAGS.modality == 'text':
                logits = model.apply(
                    params,
                    batch['input_tokens'],
                    deterministic=False,
                    rngs=rng_generator(llama_config.rng_keys()),
                ).logits
                loss, acc = cross_entropy_loss_and_accuracy(
                    logits,
                    batch['target_tokens'],
                    batch['loss_masks']
                )
                metrics = dict(acc=acc)
            elif FLAGS.modality == 'vision,text':
                vision_logits, text_logits = model.apply(
                    params,
                    batch['input_tokens'],
                    batch['input_vision_masks'],
                    deterministic=False,
                    rngs=rng_generator(llama_config.rng_keys()),
                ).logits
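                # Score each modality against its own logits: targets belonging to
                # the other modality are set to 0 and excluded through the loss mask.
                vision_loss, vision_acc = cross_entropy_loss_and_accuracy(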
                    vision_logits,
                    jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),
                    batch['loss_masks'] * batch['target_vision_masks']
                )
                text_loss, text_acc = cross_entropy_loss_and_accuracy(
                    text_logits,
                    jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),
                    batch['loss_masks'] * (1.0 - batch['target_vision_masks'])
                )
                loss = 0.5 * (vision_loss + text_loss)

                metrics = dict(
                    vision_loss=vision_loss,
                    vision_acc=vision_acc,
                    text_loss=text_loss,
                    text_acc=text_acc,
                )
            else:
                raise ValueError(f"Unsupported modality: {FLAGS.modality}")
            return loss, metrics
        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, loss_metrics), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        metrics = dict(
            loss=loss,
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            param_norm=global_norm(train_state.params),
            gradient_norm=global_norm(grads),
            **loss_metrics
        )
        return train_state, rng_generator(), metrics

    def eval_step(train_state, rng, batch):
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
        if FLAGS.modality == 'text':
            logits = model.apply(
                train_state.params,
                batch['input_tokens'],
                deterministic=True,
                rngs=rng_generator(llama_config.rng_keys()),
            ).logits
            loss, acc = cross_entropy_loss_and_accuracy(
                logits,
                batch['target_tokens'],
                batch['loss_masks']
            )
            metrics = dict(
                eval_loss=loss,
                eval_acc=acc,
            )
        elif FLAGS.modality == 'vision,text':
            vision_logits, text_logits = model.apply(
                train_state.params,
                batch['input_tokens'],
                batch['input_vision_masks'],
                deterministic=True,
                rngs=rng_generator(llama_config.rng_keys()),
            ).logits
            vision_loss, vision_acc = cross_entropy_loss_and_accuracy(
                vision_logits,
                jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),
                batch['loss_masks'] * batch['target_vision_masks']
            )
            text_loss, text_acc = cross_entropy_loss_and_accuracy(
                text_logits,
                jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),
                batch['loss_masks'] * (1.0 - batch['target_vision_masks'])
            )
            loss = 0.5 * (vision_loss + text_loss)
            metrics = dict(
                eval_loss=loss,
                eval_vision_accuracy=vision_acc,
                eval_vision_loss=vision_loss,
                eval_text_accuracy=text_acc,
                eval_text_loss=text_loss,
            )
        else:
            raise ValueError(f"Unsupported modality: {FLAGS.modality}")
        return rng_generator(), metrics

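    # Trace init_fn abstractly to get the train state's shapes and dtypes without
    # allocating any parameters; the partition specs are matched against these shapes.
    train_state_shapes = jax.eval_shape(init_fn, next_rng())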
    train_state_partition = match_partition_rules(
        config_cls.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0,
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    if FLAGS.use_data_sharded_loader:
        batch_spec = PS(('dp', 'fsdp'), 'sp')
    else:
        batch_spec = PS()
    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), batch_spec),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            llama_config=llama_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    with mesh:
        train_state, restored_params = None, None

        if FLAGS.autoresume and tux.check_exists(output_dir):
            logging.info('Found existing output. Resuming model from latest checkpoint...')
            resume_path = f"trainstate::{output_dir}/streaming_train_state"
            train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                resume_path, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30
            )
        elif FLAGS.load_checkpoint != '':
            train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30
            )

        if train_state is None and restored_params is None:
            # Initialize from scratch
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Restore from params but initialize train_state
            train_state = sharded_create_trainstate_from_params(flax.core.unfreeze(restored_params))
            del restored_params

        start_step = int(jax.device_get(train_state.step))

        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )
            if step % FLAGS.log_freq == 0:
                if FLAGS.eval_steps > 0:
                    eval_metric_list = []
                    for _ in range(FLAGS.eval_steps):
                        eval_batch, _ = next(eval_iterator)
                        sharded_rng, eval_metrics = sharded_eval_step(
                            train_state, sharded_rng, eval_batch
                        )
                        eval_metrics = jax.device_get(eval_metrics)
                        eval_metric_list.append(eval_metrics)
                    metrics.update(average_metrics(eval_metric_list))

                log_metrics = {"step": step}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                save_checkpoint(train_state)

        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)


if __name__ == "__main__":
    run(main)


================================================
FILE: lwm/vision_chat.py
================================================
from absl.app import run
import math
from tqdm import tqdm
from PIL import Image
import decord
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from transformers import GenerationConfig, AutoTokenizer
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, open_file
)
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
from lwm.vqgan import VQGAN


FLAGS, FLAGS_DEF = define_flags_with_default(
    prompt="",
    input_file="",
    vqgan_checkpoint="",
    temperature=0.2,
    max_n_frames=8,
    seed=1234,
    mesh_dim='1,-1,1,1',
    dtype='fp32',
    load_llama_config='',
    update_llama_config='',
    load_checkpoint='',
    tokenizer='LargeWorldModel/LWM-Text-1M',
    llama=VideoLLaMAConfig.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)


class Sampler:
    def __init__(self):
        self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
        self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
        self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
        self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
        self.n_tokens_per_frame = 257  # 16 x 16 VQGAN codes plus one delimiter token per frame
        self.min_buffer_size = 256
        self.sharded_rng = next_rng()
        self._load_model()

    @property
    def block_size(self):
        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']

    @property
    def data_dim(self):
        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']

    def _process_frame(self, image, size):
        width, height = image.size
        if width < height:
            new_width = size
            new_height = int(size * height / width)
        else:
            new_height = size
            new_width = int(size * width / height)
        image = image.resize((new_width, new_height))

        left = (new_width - size) / 2
        top = (new_height - size) / 2
        right = (new_width + size) / 2
        bottom = (new_height + size) / 2
        image = image.crop((left, top, right, bottom))
        return np.array(image, dtype=np.float32) / 127.5 - 1

    def _read_process_vision(self, path, max_n_frames):
        f = open_file(path, 'rb')
        if path.endswith(('.png', '.jpg')):
            image = Image.open(f).convert('RGB')
            vision = self._process_frame(image, 256)[None]
        else:
            vr = decord.VideoReader(f, ctx=decord.cpu(0))
            duration = len(vr)
            if duration <= max_n_frames:
                frame_id_list = list(range(duration))
            else:
                frame_id_list = np.linspace(0, duration - 1, max_n_frames, dtype=int).tolist()
            video = vr.get_batch(frame_id_list).asnumpy()
            vision = np.stack([self._process_frame(Image.fromarray(frame), 256) for frame in video])

        B = 1
        encodings = []
        for i in range(0, len(vision), B):  # step by the chunk size so chunks never overlap
            v = vision[i:i + B]
            if len(v) % B == 0:
                n_pad = 0
            else:
                n_pad = B - len(v) % B
            v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))
            enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
            enc = enc[n_pad:]
            for t in range(len(enc)):
                encodings.extend(enc[t].reshape(-1).tolist())
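                if t == len(enc) - 1:
                    # Delimiter after the last frame of this chunk (every frame when B == 1).
                    encodings.append(8193)
                else:
                    # Delimiter after any earlier frame in the chunk.
                    encodings.append(8192)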
        return encodings

    def construct_input(self, prompts, max_n_frames):
        max_input_length = max_n_frames * self.n_tokens_per_frame + self.min_buffer_size
        max_input_length = int(math.ceil(max_input_length / self.block_size) * self.block_size)

        vision_start = self.tokenizer.encode('<vision>')
        vision_end = self.tokenizer.encode('</vision>')

        input_ids = np.zeros((len(prompts), max_input_length), dtype=int)
        vision_masks = np.zeros((len(prompts), max_input_length), dtype=bool)
        attention_mask = np.zeros((len(prompts), max_input_length), dtype=int)
        for i, prompt in enumerate(tqdm(prompts)):
            vision = self._read_process_vision(prompt['input_path'], max_n_frames)
            text_1 = self.tokenizer.encode(f"<s>You are a helpful assistant. USER: {prompt['question']}\n")
            tail = self.tokenizer.encode(" ASSISTANT:")

            tokens, vm = [], []
            tokens.extend(text_1)
            vm.extend([False] * len(text_1))
            tokens.extend(vision_start)
            vm.extend([False] * len(vision_start))
            tokens.extend(vision)
            vm.extend([True] * len(vision))
            tokens.extend(vision_end)
            vm.extend([False] * len(vision_end))
            tokens.extend(tail)
            vm.extend([False] * len(tail))
            assert len(tokens) < max_input_length, (len(tokens), max_input_length)
            assert len(tokens) == len(vm)
            input_ids[i, -len(tokens):] = tokens
            vision_masks[i, -len(tokens):] = vm
            attention_mask[i, -len(tokens):] = 1
        return {
            'input_ids': input_ids,
            'vision_masks': vision_masks,
            'attention_mask': attention_mask
        }


    def _load_model(self):
        if FLAGS.load_llama_config != '':
            llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
            updates = VideoLLaMAConfig(**FLAGS.llama)
            llama_config.update(dict(
                scan_attention=updates.scan_attention,
                scan_mlp=updates.scan_mlp,
                scan_query_chunk_size=updates.scan_query_chunk_size,
                scan_key_chunk_size=updates.scan_key_chunk_size,
                scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
                scan_layers=updates.scan_layers,
                param_scan_axis=updates.param_scan_axis,
            ))
        else:
            llama_config = VideoLLaMAConfig(**FLAGS.llama)

        if FLAGS.update_llama_config != '':
            # NOTE: eval() executes arbitrary Python; only pass trusted values to this flag.
            llama_config.update(dict(eval(FLAGS.update_llama_config)))

        llama_config.update(dict(
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        ))
        llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
        self.config = llama_config

        self.model = FlaxVideoLLaMAForCausalLM(
            llama_config,
            input_shape=(512, self.block_size),
            seed=FLAGS.seed,
            _do_init=False,
            dtype=get_float_dtype_by_name(FLAGS.dtype),
        )

        with jax.default_device(jax.devices("cpu")[0]):
            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
            )
        self.model_ps = match_partition_rules(
            VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
        )
        shard_fns, _ = make_shard_and_gather_fns(
            self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
        )

        with self.mesh:
            self.params = tree_apply(shard_fns, self.params)

    @cached_property
    def _forward_generate(self):
        def fn(params, rng, batch):
            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
            rng_generator = JaxRNG(rng)
            output = self.model.generate(
                batch['input_ids'],
                vision_masks=batch['vision_masks'],
                attention_mask=batch['attention_mask'],
                params=params['params'],
                prng_key=rng_generator(),
                generation_config=GenerationConfig(
                    max_new_tokens=self.block_size,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    temperature=FLAGS.temperature,
                    do_sample=True,
                )
            ).sequences[:, batch['input_ids'].shape[1]:]
            return output, rng_generator()
        return pjit(
            fn,
            in_shardings=(self.model_ps, PS(), PS()),
            out_shardings=(PS(), PS())
        )

    def __call__(self, prompts, max_n_frames):
        batch = self.construct_input(prompts, max_n_frames)
        with self.mesh:
            output, self.sharded_rng = self._forward_generate(
                self.params, self.sharded_rng, batch
            )
            output = jax.device_get(output)
        output_text = []
        for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
            if self.tokenizer.eos_token in text:
                text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
            output_text.append(text)
        return output_text

def main(argv):
    assert FLAGS.prompt != '', 'Please specify --prompt'
    assert FLAGS.input_file != '', 'Please specify --input_file'

    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]
    sampler = Sampler()
    output = sampler(prompts, FLAGS.max_n_frames)[0]
    print(f"Question: {FLAGS.prompt}\nAnswer: {output}")

if __name__ == "__main__":
    run(main)


================================================
FILE: lwm/vision_generation.py
================================================
from absl.app import run
from tqdm import tqdm
import imageio
import numpy as np
from PIL import Image
from transformers import GenerationConfig, AutoTokenizer
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, next_rng
)
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
from lwm.vqgan import VQGAN


FLAGS, FLAGS_DEF = define_flags_with_default(
    prompt='Fireworks over the city',
    output_file='',
    temperature_image=1.0,
    temperature_video=1.0,
    top_k_image=8192,
    top_k_video=100,
    cfg_scale_image=1.0,
    cfg_scale_video=1.0,
    vqgan_checkpoint='',
    n_frames=1,
    seed=1234,
    mesh_dim='1,-1,1,1',
    dtype='fp32',
    load_llama_config='',
    update_llama_config='',
    load_checkpoint='',
    tokenizer='LargeWorldModel/LWM-Text-1M',
    llama=VideoLLaMAConfig.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)


def main(argv):
    assert FLAGS.output_file != ''
    if FLAGS.output_file.endswith('mp4'):
        assert FLAGS.n_frames > 1
    elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'):
        assert FLAGS.n_frames == 1
    else:
        raise ValueError(f"Unsupported output file extension: {FLAGS.output_file}")

    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    set_random_seed(FLAGS.seed)

    tokens_per_frame = 257
    vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
    mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
    prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
    if FLAGS.load_llama_config != '':
        llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
        updates = VideoLLaMAConfig(**FLAGS.llama)
        llama_config.update(dict(
            scan_attention=updates.scan_attention,
            scan_mlp=updates.scan_mlp,
            scan_query_chunk_size=updates.scan_query_chunk_size,
            scan_key_chunk_size=updates.scan_key_chunk_size,
            scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
            scan_layers=updates.scan_layers,
            param_scan_axis=updates.param_scan_axis,
        ))
    else:
        llama_config = VideoLLaMAConfig(**FLAGS.llama)

    if FLAGS.update_llama_config != '':
        # NOTE: eval() executes arbitrary Python; only pass trusted values to this flag.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    llama_config.update(dict(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ))
    llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))

    with jax.default_device(jax.devices("cpu")[0]):
        _, params = StreamingCheckpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
        )
        model = FlaxVideoLLaMAForCausalLM(
            llama_config,
            input_shape=(512, 8192),
            seed=FLAGS.seed,
            _do_init=False,
            dtype=get_float_dtype_by_name(FLAGS.dtype),
        )
        model_ps = match_partition_rules(
            VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params
        )
        shard_fns, _ = make_shard_and_gather_fns(
            model_ps, get_float_dtype_by_name(FLAGS.dtype)
        )

        with mesh:
            params = tree_apply(shard_fns, params)

    def _forward_generate(params, rng, batch, n_tokens, cfg_scale, top_k, temperature):
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
        cfg_scales = jnp.ones((batch['input_ids'].shape[0] // 2,), dtype=jnp.float32) * cfg_scale
        cfg_scales = with_sharding_constraint(cfg_scales, PS(('dp', 'fsdp')))
        rng_generator = JaxRNG(rng)
        output = model.generate_vision(
            batch['input_ids'],
            cfg_scales,
            attention_mask=batch['attention_mask'],
            vision_masks=batch['vision_masks'],
            params=params['params'],
            prng_key=rng_generator(),
            generation_config=GenerationConfig(
                max_new_tokens=n_tokens,
                min_new_tokens=n_tokens,
                pad_token_id=tokenizer.pad_token_id,
                temperature=temperature,
                do_sample=True,
                top_k=top_k,
            )
        ).sequences[:, batch['input_ids'].shape[1]:]
        return output, rng_generator()
    _sharded_forward_generate = pjit(
        _forward_generate,
        in_shardings=(model_ps, PS(), PS()),
        out_shardings=(PS(), PS()),
        static_argnums=(3, 4, 5, 6)
    )

    # Generate an image or first frame (for video)
    def generate_first_frame(prompts, max_input_length):
        nonlocal sharded_rng
        uncond_prompts = ["<s><vision>"] * len(prompts)
        prompts = prompts + uncond_prompts
        inputs = prefix_tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=max_input_length,
            return_tensors='np'
        )
        batch = dict(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            vision_masks=np.zeros(inputs.input_ids.shape, dtype=bool),
        )
        with mesh:
            output, sharded_rng = _sharded_forward_generate(
                params, sharded_rng, batch,
                tokens_per_frame, FLAGS.cfg_scale_image,
                FLAGS.top_k_image, FLAGS.temperature_image
            )
            output = jax.device_get(output)
            output = np.split(output, 2, axis=0)[0]
        output = output.reshape(len(prompts) // 2, tokens_per_frame)
        image = vqgan.decode(output[:, :-1].reshape(-1, 16, 16))
        image = ((jax.device_get(image) + 1) * 127.5).astype(np.uint8)
        return output, image

    sharded_rng = next_rng()
    prompts = [FLAGS.prompt]
    entries = []
    for prompt in prompts:
        entries.append({
            'caption': prompt,
            'prompt': f"<s>You are a helpful assistant. USER: Generate an image of {prompt} ASSISTANT: <vision>",
        })

    B = 1
    images, image_encodings = [], []
    for i in tqdm(list(range(0, len(entries), B))):
        entries_i = entries[i:i + B]
        prompts = [entry['prompt'] for entry in entries_i]
        img_enc, img = generate_first_frame(prompts, max_input_length=128)
        image_encodings.extend(img_enc)
        images.extend(img)

    if FLAGS.n_frames == 1:
        image = images[0]
        Image.fromarray(image).save(FLAGS.output_file)
        return

    # Generate the rest of the video
    def generate_video_pred(prompts, images, max_input_length):
        nonlocal sharded_rng
        images = np.concatenate([images, images], axis=0)
        uncond_prompts = ["<s><vision>"] * len(prompts)
        prompts = prompts + uncond_prompts
        inputs = prefix_tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=max_input_length,
            return_tensors='np'
        )
        batch = dict(
            input_ids=np.concatenate([inputs.input_ids, images], axis=1),
            attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype)], axis=1),
            vision_masks=np.concatenate([
                np.zeros(inputs.input_ids.shape, dtype=bool),
                np.ones(images.shape, dtype=bool)
            ], axis=1),
        )
        with mesh:
            output, sharded_rng = _sharded_forward_generate(
                params, sharded_rng, batch,
                (FLAGS.n_frames - 1) * tokens_per_frame, FLAGS.cfg_scale_video,
                FLAGS.top_k_video, FLAGS.temperature_video
            )
            output = jax.device_get(output)
            output = np.split(output, 2, axis=0)[0]
        output = output.reshape(len(prompts) // 2, FLAGS.n_frames - 1, tokens_per_frame)
        output = np.concatenate([images[:len(prompts) // 2, None], output], axis=1)
        output = output[:, :, :-1].reshape(-1, FLAGS.n_frames, 16, 16)
        vision = []
        for v in output:
            v = vqgan.decode(v)
            v = ((jax.device_get(v) + 1) * 127.5).astype(np.uint8)
            vision.append(v)
        return vision

    new_entries = []
    for img_enc, entry in zip(image_encodings, entries):
        new_entries.append({
            'caption': entry['caption'],
            'prompt': f"<s>You are a helpful assistant. USER: Generate a video of {entry['caption']} ASSISTANT: <vision>",
            'image': np.array(img_enc, dtype=np.int32),
        })
    entries = new_entries

    B = 1
    videos = []
    for i in tqdm(list(range(0, len(entries), B))):
        entries_i = entries[i:i + B]
        prompts = [entry['prompt'] for entry in entries_i]
        images = np.array([entry['image'] for entry in entries_i], dtype=np.int32)
        videos.extend(generate_video_pred(prompts, images, max_input_length=128))

    video = videos[0]
    writer = imageio.get_writer(FLAGS.output_file, fps=4)
    for frame in video:
        writer.append_data(frame)
    writer.close()

    print('done')

if __name__ == "__main__":
    run(main)


================================================
FILE: lwm/vision_llama.py
================================================
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import warnings
import copy

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as PS
import flax.linen as nn
from flax.core.frozen_dict import unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict

from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from transformers.generation.flax_utils import SampleState, FlaxLogitsProcessorList, FlaxSampleOutput, logger
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers import GenerationConfig

from tux import load_pickle, open_file
from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm


VIDEO_LLAMA_STANDARD_CONFIGS = LLAMA_STANDARD_CONFIGS


class VideoLLaMAConfig(LLaMAConfig):
    model_type = "video_llama"

    def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_mode='all', **kwargs):
        super().__init__(**kwargs)
        self.vision_vocab_size = vision_vocab_size  # default: 8448 = 8192 + 256
        self.tie_vision_embeddings = tie_vision_embeddings
        self.sample_mode = sample_mode

    @staticmethod
    def get_partition_rules(scan_layers=False, scan_axis=0):
        """Partition rules are ordered, so that earlier rules match first."""
        if scan_layers:
            if scan_axis == 0:
                return (
                    # embeddings
                    ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
                    ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
                    # attention
                    ("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")),
                    ("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))),
                    # mlp
                    ("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")),
                    ("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))),
                    ("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")),
                    # layer norms
                    ("attention_norm/kernel", PS(None, None)),
                    ("ffn_norm/kernel", PS(None, None)),
                    # output head
                    ("transformer/ln_f/kernel", PS(None)),
                    ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
                    ("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
                    ('.*', PS(None)),
                )
            elif scan_axis == 1:
                return (
                    # embeddings
                    ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
                    ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
                    # attention
                    ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")),
                    ("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))),
                    # mlp
                    ("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")),
                    ("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))),
                    ("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")),
                    # layer norms
                    ("attention_norm/kernel", PS(None, None)),
                    ("ffn_norm/kernel", PS(None, None)),
                    # output head
                    ("transformer/ln_f/kernel", PS(None)),
                    ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
                    ("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
                    ('.*', PS(None)),
                )
            else:
                raise ValueError(f"Invalid scan_axis {scan_axis}")
        else:
            return (
                # embeddings
                ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
                ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
                # attention
                ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")),
                ("attention/wo/kernel", PS("tp", ("fsdp", "sp"))),
                # mlp
                ("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")),
                ("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))),
                ("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")),
                # layer norms
                ("attention_norm/kernel", PS(None)),
                ("ffn_norm/kernel", PS(None)),
                # output head
                ("transformer/ln_f/kernel", PS(None)),
                ("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
                ("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
                ('.*', PS(None)),
            )

    @classmethod
    def load_config(cls, path):
        if path in VIDEO_LLAMA_STANDARD_CONFIGS:
            return cls.from_dict(VIDEO_LLAMA_STANDARD_CONFIGS[path])
        load_type, load_path = path.split('::', 1)
        if load_type == 'pickle':
            return cls.from_dict(load_pickle(load_path)['llama_config'])
        elif load_type == 'json':
            with open_file(load_path, 'r') as fin:
                raw_config = fin.read()
            return cls.from_dict(json.loads(raw_config))
        else:
            raise ValueError(f'Unsupported load config type: {load_type}')


class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VideoLLaMAConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: VideoLLaMAConfig,
        input_shape: Tuple = (4, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_cache(self, batch_size, max_length):
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        segment_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
        vision_masks = jnp.ones((batch_size, max_length), dtype=bool)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    def init_weights(self, rng, input_shape, params=None):
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        vision_masks = jnp.ones(input_ids.shape, dtype=bool)
        segment_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward("")
    def __call__(
        self,
        input_ids,
        vision_masks,
        attention_mask=None,
        segment_ids=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if segment_ids is None:
            segment_ids = jnp.zeros((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(vision_masks, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(segment_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs


class FlaxVideoLLaMAModule(nn.Module):
    config: VideoLLaMAConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
        self.embed_dim = self.config.hidden_size

        self.vte = nn.Embed(
            self.config.vision_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

        self.wte = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
        self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
        self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)

    def __call__(
        self,
        input_ids,
        vision_masks,
        attention_mask,
        segment_ids,
        position_ids,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        input_ids = input_ids.astype("i4")

        if input_ids.shape[1] == 1:
            if self.config.sample_mode == 'text':
                input_embeds = self.wte(input_ids)
            elif self.config.sample_mode == 'vision':
                input_embeds = self.vte(input_ids)
            elif self.config.sample_mode == 'all':
                raise NotImplementedError
            else:
                raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}")
        else:
            input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
            input_vision_embeds = self.vte(jnp.where(vision_masks, input_ids, 0))
            vision_masks = vision_masks[..., None].astype("f4") # 1 is vision, 0 is text
            input_embeds = input_text_embeds * (1 - vision_masks) + input_vision_embeds * vision_masks

        hidden_states = self.dropout(input_embeds, deterministic=deterministic)

        outputs = self.h(
            hidden_states,
            attention_mask,
            segment_ids,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
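

# Illustrative sketch, not part of the original module: the text/vision embedding
# mix performed in FlaxVideoLLaMAModule.__call__ above, written with NumPy so it
# can be run standalone. The helper name and the toy tables used to exercise it
# are hypothetical.
import numpy as np


def _sketch_mix_embeddings(input_ids, vision_masks, text_table, vision_table):
    # Look up every position in both tables, substituting id 0 where the
    # position belongs to the other modality so the lookup stays in range.
    text_embeds = text_table[np.where(vision_masks, 0, input_ids)]
    vision_embeds = vision_table[np.where(vision_masks, input_ids, 0)]
    # Broadcast the mask over the hidden dimension: 1 is vision, 0 is text.
    mask = vision_masks[..., None].astype(np.float32)
    return text_embeds * (1 - mask) + vision_embeds * mask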


class FlaxVideoLLaMAForCausalLMModule(nn.Module):
    config: VideoLLaMAConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self):
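        # Forward param_dtype and precision so the backbone matches the heads below.
        self.transformer = FlaxVideoLLaMAModule(
            self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision
        )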
        self.vision_head = nn.Dense(
            self.config.vision_vocab_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            precision=self.precision,
        )
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            precision=self.precision,
        )

    def __call__(
        self,
        input_ids,
        vision_masks,
        attention_mask=None,
        segment_ids=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        batch_size, seq_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if segment_ids is None:
            segment_ids = jnp.zeros_like(input_ids)
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                # Positions count attended tokens; clamp at 0 for leading padding.
                jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0),
                (batch_size, seq_length)
            )


        outputs = self.transformer(
            input_ids,
            vision_masks,
            attention_mask,
            segment_ids,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if self.config.tie_vision_embeddings:
            shared_kernel = self.transformer.variables["params"]["vte"]["embedding"].T
            vision_logits = self.vision_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            vision_logits = self.vision_head(hidden_states)

        if self.config.tie_word_embeddings:
            shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if self.config.sample_mode == 'all':
            if not return_dict:
                return (vision_logits, lm_logits,) + outputs[1:]

            return FlaxCausalLMOutput(logits=(vision_logits, lm_logits), hidden_states=outputs.hidden_states, attentions=outputs.attentions)
        elif self.config.sample_mode == 'vision':
            if not return_dict:
                return (vision_logits,) + outputs[1:]

            return FlaxCausalLMOutput(logits=vision_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
        elif self.config.sample_mode == 'text':
            if not return_dict:
                return (lm_logits,) + outputs[1:]

            return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
        else:
            raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}")



@add_start_docstrings("", "")
class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel):
    module_class = FlaxVideoLLaMAForCausalLMModule

    def prepare_inputs_for_generation(
        self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, vision_masks = None
    ):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
            "vision_masks": vision_masks
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        return {
            "past_key_values":  model_outputs.past_key_values,
            "position_ids": model_kwargs["position_ids"][:, -1:] + 1,
            "attention_mask": model_kwargs["attention_mask"],
            "vision_masks": model_kwargs["vision_masks"]
        }

    def _sample_vision(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        cfg_scales: jnp.ndarray = 1.0,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.generation_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape
        initial_len = cur_len

        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

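            logits = model_outputs.logits[:, -1]
            # Classifier-free guidance: the first half of the batch holds the conditional
            # prompts and the second half the unconditional ones.
            cond_logits, uncond_logits = jnp.split(logits, 2, axis=0)
            # Reshape rather than index so the scalar default for `cfg_scales` also works.
            logits = uncond_logits + jnp.reshape(cfg_scales, (-1, 1)) * (cond_logits - uncond_logits)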

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

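            next_token = jax.random.categorical(prng_key, logits, axis=-1)
            # Every 257th generated position is forced to token 8192 instead of a sampled
            # one. 8192 is the first id past the 8192-entry VQGAN codebook, so this appears
            # to act as a frame delimiter following each block of 256 codes.
            next_token = jax.lax.cond(
                (state.cur_len - initial_len + 1) % 257 == 0,
                lambda: jnp.full_like(next_token, 8192),
                lambda: next_token
            )
            # Duplicate the tokens so the conditional and unconditional halves of the
            # batch continue with the same sequence.
            next_token = jnp.concatenate([next_token, next_token], axis=0)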

            #next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)

    def generate_vision(
        self,
        input_ids: jnp.ndarray,
        cfg_scales: jnp.ndarray,
        generation_config: Optional[GenerationConfig] = None,
        prng_key: Optional[jnp.ndarray] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        **kwargs,
    ):
        # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
            # two conditions must be met
            # 1) the generation config must have been created from the model config (`_from_model_config` field);
            # 2) the generation config must have seen no modification since its creation (the hash is the same).
            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
                self.generation_config
            ):
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed in a future version."
                        " Please use and modify the model generation configuration (see"
                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()

        # set init values
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask") is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

        # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
        if not self.config.is_encoder_decoder and not trace:
            if (
                generation_config.pad_token_id is not None
                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        batch_size = input_ids.shape[0]

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
            # prepare decoder_input_ids for generation
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                model_kwargs=model_kwargs,
            )

        # Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                f" the maximum length ({generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            logits_processor=logits_processor,
        )

        if not generation_config.do_sample and generation_config.num_beams == 1:
            raise NotImplementedError
        elif generation_config.do_sample and generation_config.num_beams == 1:
            logits_warper = self._get_logits_warper(generation_config=generation_config)
            return self._sample_vision(
                input_ids,
                generation_config.max_length,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                cfg_scales=cfg_scales,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not generation_config.do_sample and generation_config.num_beams > 1:
            raise NotImplementedError
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")
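

# Illustrative sketch, not part of the original module: the classifier-free
# guidance step used inside `_sample_vision`, written with NumPy so it can be run
# standalone. The helper name is hypothetical.
import numpy as np


def _sketch_cfg_logits(logits, cfg_scales):
    # The first half of the batch is conditional, the second half unconditional.
    cond_logits, uncond_logits = np.split(logits, 2, axis=0)
    # Scale 0 returns the unconditional logits and scale 1 the conditional ones;
    # larger scales extrapolate past the conditional logits.
    return uncond_logits + cfg_scales[:, None] * (cond_logits - uncond_logits)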


================================================
FILE: lwm/vqgan.py
================================================
from typing import Optional
from functools import cached_property, partial
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import jax_utils
from transformers.configuration_utils import PretrainedConfig
from ml_collections import ConfigDict
from tux import function_args_to_config, open_file


class VQGAN:
    def __init__(self, vqgan_checkpoint, replicate=False):
        assert vqgan_checkpoint != ''
        self.replicate = replicate
        self.config = VQGANConfig.get_default_config()
        # Close the checkpoint file once the parameters are loaded.
        with open_file(vqgan_checkpoint, 'rb') as fin:
            self.params = pickle.load(fin)
        if replicate:
            self.params = jax_utils.replicate(self.params)
        else:
            self.params = jax.jit(lambda x: x)(self.params)
        self.model = VQGANModel(self.config)

    def _wrap_fn(self, fn):
        if self.replicate:
            return jax.pmap(fn, devices=jax.local_devices())
        else:
            return jax.jit(fn)
    
    @cached_property
    def _encode(self):
        def fn(pixel_values, params):
            return self.model.apply(
                {'params': params}, 
                pixel_values,
                method=self.model.encode
            )
        return partial(self._wrap_fn(fn), params=self.params)
    
    @cached_property
    def _decode(self):
        def fn(encoding, params):
            return self.model.apply(
                {'params': params},
                encoding,
                method=self.model.decode
            )
        return partial(self._wrap_fn(fn), params=self.params)
    
    def encode(self, pixel_values):
        return self._encode(pixel_values)
    
    def decode(self, encoding):
        return self._decode(encoding)
    

class VQGANConfig(PretrainedConfig):
    model_type = "vqgan"

    def __init__(
        self,
        resolution=256,
        num_channels=3,
        hidden_channels=128,
        channel_mult=(1, 2, 2, 4, 6),
        num_res_blocks=2,
        attn_resolutions=(),
        no_attn_mid_block=True,
        z_channels=64,
        num_embeddings=8192,
        quantized_embed_dim=64,
        dropout=0.0,
        resample_with_conv=True,
        commitment_cost=0.25
    ):
        self.resolution = resolution
        self.num_channels = num_channels
        self.hidden_channels = hidden_channels
        self.channel_mult = channel_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.no_attn_mid_block = no_attn_mid_block
        self.z_channels = z_channels
        self.num_embeddings = num_embeddings
        self.quantized_embed_dim = quantized_embed_dim
        self.dropout = dropout
        self.resample_with_conv = resample_with_conv
        self.commitment_cost = commitment_cost
    
    @classmethod
    def get_default_config(cls, updates=None):
        config = function_args_to_config(cls.__init__)
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        config.num_resolutions = len(config.channel_mult)
        return config
    
    @classmethod
    def load_config(cls, path):
        # `path` is currently unused; the default configuration is always returned.
        return cls.get_default_config()

        
class VQGANModel(nn.Module):
    config: VQGANConfig

    def setup(self):
        self.encoder = Encoder(self.config)
        self.decoder = Decoder(self.config)
        self.quantize = VectorQuantizer(
            self.config.num_embeddings, self.config.quantized_embed_dim
        )
        self.quant_conv = nn.Conv(self.config.quantized_embed_dim, [1, 1])
        self.post_quant_conv = nn.Conv(self.config.z_channels, [1, 1])
    
    def encode(self, pixel_values):
        T = None
        if len(pixel_values.shape) == 5: # video
            T = pixel_values.shape[1]
            pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quantized_states, codebook_indices = self.quantize(hidden_states)
        if T is not None:
            quantized_states = quantized_states.reshape(-1, T, *quantized_states.shape[1:])
            codebook_indices = codebook_indices.reshape(-1, T, *codebook_indices.shape[1:])
        return quantized_states, codebook_indices

    def decode(self, encoding, is_codebook_indices=True):
        if is_codebook_indices:
            encoding = self.quantize(None, encoding)
        T = None
        if len(encoding.shape) == 5:
            T = encoding.shape[1]
            encoding = encoding.reshape(-1, *encoding.shape[2:])
        hidden_states = self.post_quant_conv(encoding)
        reconstructed_pixel_values = self.decoder(hidden_states)
        if T is not None:
            reconstructed_pixel_values = reconstructed_pixel_values.reshape(-1, T, *reconstructed_pixel_values.shape[1:])
        return jnp.clip(reconstructed_pixel_values, -1, 1)
    
    def __call__(self, pixel_values):
        encoding = self.encode(pixel_values)[1]
        recon = self.decode(encoding)
        return recon
    

class Encoder(nn.Module):
    config: VQGANConfig
    
    @nn.compact
    def __call__(self, pixel_values):
        assert pixel_values.shape[1] == pixel_values.shape[2] == self.config.resolution, pixel_values.shape
        hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)
        for i_level in range(self.config.num_resolutions):
            hidden_states = DownsamplingBlock(self.config, i_level)(hidden_states)
        hidden_states = MidBlock(
            self.config, self.config.no_attn_mid_block, self.config.dropout
        )(hidden_states)
        hidden_states = nn.GroupNorm()(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = nn.Conv(self.config.z_channels, [3, 3])(hidden_states)
        return hidden_states

        
class Decoder(nn.Module):
    config: VQGANConfig

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = nn.Conv(
            self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1],
            [3, 3]
        )(hidden_states)
        hidden_states = MidBlock(
            self.config, self.config.no_attn_mid_block, self.config.dropout
        )(hidden_states)
        for i_level in reversed(range(self.config.num_resolutions)):
            hidden_states = UpsamplingBlock(self.config, i_level)(hidden_states)
        hidden_states = nn.GroupNorm()(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = nn.Conv(self.config.num_channels, [3, 3])(hidden_states)
        return hidden_states


class VectorQuantizer(nn.Module):
    n_e: int
    e_dim: int

    @nn.compact
    def __call__(self, z, encoding_indices=None):
        embeddings = self.param(
            'embeddings',
            lambda rng, shape, dtype: jax.random.uniform(
                rng, shape, dtype, minval=-1.0 / self.n_e, maxval=1.0 / self.n_e
            ),
            [self.n_e, self.e_dim], jnp.float32
        )

        def quantize(encoding_indices):
            # Look up the codebook vectors for the given indices.
            w = jax.device_put(embeddings)
            return w[(encoding_indices,)]
        
        if encoding_indices is not None:
            return quantize(encoding_indices)

        z_flattened = z.reshape(-1, z.shape[-1])
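        # Squared Euclidean distance between every input vector and every codebook
        # entry, expanded as |z|^2 + |e|^2 - 2 * (z . e).
        d = jnp.sum(z_flattened ** 2, axis=1, keepdims=True) + \
            jnp.sum(embeddings.T ** 2, axis=0, keepdims=True) - \
            2 * jnp.einsum('bd,nd->bn', z_flattened, embeddings)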
        
        min_encoding_indices = jnp.argmin(d, axis=1)
        z_q = quantize(min_encoding_indices)
        z_q = jnp.reshape(z_q, z.shape)
        z_q = z + jax.lax.stop_gradient(z_q - z)

        encodings_one_hot = jax.nn.one_hot(min_encoding_indices, num_classes=self.n_e)
        assert len(encodings_one_hot.shape) == 2
        min_encoding_indices = jnp.reshape(min_encoding_indices, z.shape[:-1])

        return z_q, min_encoding_indices
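

# Illustrative sketch, not part of the original module: the nearest-codebook lookup
# from VectorQuantizer.__call__ above, written with NumPy so it can be run
# standalone. The helper name is hypothetical, and the straight-through gradient
# step (`stop_gradient`) is left out because NumPy does not track gradients.
import numpy as np


def _sketch_nearest_code(z, codebook):
    z_flat = z.reshape(-1, z.shape[-1])
    # Squared Euclidean distance, expanded as |z|^2 + |e|^2 - 2 * (z . e).
    d = (
        np.sum(z_flat ** 2, axis=1, keepdims=True)
        + np.sum(codebook ** 2, axis=1)[None, :]
        - 2 * z_flat @ codebook.T
    )
    indices = np.argmin(d, axis=1)
    # Return the quantized vectors and their indices in the input's layout.
    return codebook[indices].reshape(z.shape), indices.reshape(z.shape[:-1])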


class DownsamplingBlock(nn.Module):
    config: VQGANConfig
    block_idx: int

    @nn.compact
    def __call__(self, hidden_states):
        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
        for _ in range(self.config.num_res_blocks):
            hidden_states = ResnetBlock(
                block_out, dropout_prob=self.config.dropout
            )(hidden_states) 
            if hidden_states.shape[1] in self.config.attn_resolutions:
                hidden_states = AttnBlock()(hidden_states)
        if self.block_idx != self.config.num_resolutions - 1:
            hidden_states = Downsample(self.config.resample_with_conv)(hidden_states)
        return hidden_states


class ResnetBlock(nn.Module):
    out_channels: Optional[int] = None
    use_conv_shortcut: bool = False
    dropout_prob: float = 0.0

    @nn.compact
    def __call__(self, hidden_states):
        out_channels = self.out_channels or hidden_states.shape[-1]
        residual = hidden_states
        hidden_states = nn.GroupNorm()(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)
        hidden_states = nn.GroupNorm()(hidden_states)
        hidden_states = nn.silu(hidden_states)
        # deterministic=True disables dropout, so this layer is a no-op as written.
        hidden_states = nn.Dropout(self.dropout_prob, deterministic=True)(hidden_states)
        hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)
        if out_channels != residual.shape[-1]:
            if self.use_conv_shortcut:
                residual = nn.Conv(out_channels, [3, 3])(residual)
            else:
                residual = nn.Conv(out_channels, [1, 1])(residual)
        return hidden_states + residual
        

class AttnBlock(nn.Module):
    @nn.compact
    def __call__(self, hidden_states):
        residual = hidden_states
        hidden_states = nn.GroupNorm()(hidden_states)
        query = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
        key = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
        value = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
        query, key, value = map(
            lambda x: x.reshape(x.shape[0], -1, x.shape[-1]),
            [query, key, value]
        )
        attn_weights = jnp.einsum("bqd,bkd->bqk", query, key)
        attn_weights *= hidden_states.shape[-1] ** -0.5
        attn_weights = jax.nn.softmax(attn_weights, axis=-1)
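        hidden_states = jnp.einsum("bqk,bkd->bqd", attn_weights, value)
        # Restore the (B, H, W, C) layout; without this the (B, H*W, C) result
        # cannot be added to the 4D residual below.
        hidden_states = hidden_states.reshape(residual.shape)
        hidden_states = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
        return hidden_states + residual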

        
class Downsample(nn.Module):
    with_conv: bool
    
    @nn.compact
    def __call__(self, hidden_states):
        if self.with_conv:
            hidden_states = jnp.pad(
                hidden_states,
                [(0, 0), (0, 1), (0, 1), (0, 0)]
            )
            hidden_states = nn.Conv(
                hidden_states.shape[-1], [3, 3], 
                strides=[2, 2], 
                padding="VALID"
            )(hidden_states)
        else:
            hidden_states = nn.avg_pool(hidden_states, [2, 2], [2, 2])
        return hidden_states
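

# --- Illustrative sketch, not part of the upstream lwm/vqgan.py ---
# Both branches of Downsample halve an even spatial size. The hypothetical
# helper below works through the output-size arithmetic for one spatial axis,
# assuming the default "VALID" padding of nn.avg_pool. Its name is an
# assumption made for this sketch only.
def _downsample_out_size_sketch(size, with_conv=True):
    """Return the output length of one spatial axis after Downsample."""
    if with_conv:
        padded = size + 1  # jnp.pad appends one row/column at the end
        return (padded - 3) // 2 + 1  # 3x3 kernel, stride 2, VALID padding
    return (size - 2) // 2 + 1  # 2x2 average pool, stride 2, VALID padding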

        
class Upsample(nn.Module):
    with_conv: bool

    @nn.compact
    def __call__(self, hidden_states):
        B, H, W, C = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            (B, H * 2, W * 2, C),
            method="nearest"
        )
        if self.with_conv:
            hidden_states = nn.Conv(hidden_states.shape[-1], [3, 3])(hidden_states)
        return hidden_states


class UpsamplingBlock(nn.Module):
    config: VQGANConfig
    block_idx: int

    @nn.compact
    def __call__(self, hidden_states):
        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
        for _ in range(self.config.num_res_blocks + 1):
            hidden_states = ResnetBlock(
                block_out, dropout_prob=self.config.dropout
            )(hidden_states)
            if hidden_states.shape[1] in self.config.attn_resolutions:
                hidden_states = AttnBlock()(hidden_states)
        if self.block_idx != 0:
            hidden_states = Upsample(self.config.resample_with_conv)(hidden_states)
        return hidden_states


class MidBlock(nn.Module):
    config: VQGANConfig
    no_attn: bool
    dropout: float

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)
        if not self.no_attn:
            hidden_states = AttnBlock()(hidden_states)
        hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)
        return hidden_states


================================================
FILE: scripts/create_needle_data.py
================================================
import os
import argparse
import json
from tqdm import tqdm
from datasets import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--output_path", type=str, default="data/pg19.jsonl")
args = parser.parse_args()

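# Create the output directory only when the path has a directory component;
# os.makedirs("") raises FileNotFoundError for a bare filename.
output_dir = os.path.dirname(args.output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)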

dset = load_dataset("pg19")["train"]
with open(args.output_path, "w") as f:
    for elem in tqdm(dset):
        data = {"text": elem["text"]}
        f.write(f"{json.dumps(data)}\n")

================================================
FILE: scripts/eval_needle.py
================================================
from absl.app import run
import time
import json
import math
import os
from tqdm import tqdm
import random
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import gcsfs
import tiktoken
from transformers import GenerationConfig, AutoTokenizer
from tux import (
    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
    match_partition_rules, make_shard_and_gather_fns,
    with_sharding_constraint, tree_apply, open_file
)
from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM


FLAGS, FLAGS_DEF = define_flags_with_default(
    haystack_file="",
    max_tokens_per_batch=2000000,
    output_file="results.json",
    context_lengths_min=1000,
    context_lengths_max=32000,
    n_context_length_intervals=3,
    n_document_depth_intervals=3,
    n_rounds=2,
    seed=1234,
    mesh_dim='1,-1,1,1',
    dtype='fp32',
    load_llama_config='',
    update_llama_config='',
    load_checkpoint='',
    tokenizer='LargeWorldModel/LWM-Text-1M',
    checkpointer=StreamingCheckpointer.get_default_config(),
    llama=LLaMAConfig.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)


class LLMNeedleHaystackTester:
    OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document or repeat your findings. Keep your response short and direct. ASSISTANT: "
    RANDOM_NEEDLE_CITIES  = [
        'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
        'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
        'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
        'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
        'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
        'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
        'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
        'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
        'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
        'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
        'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
        'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
    ]

    def __init__(self,
                 needle="",
                 haystack_file="",
                 retrieval_question="What is the special magic {} number?",
                 results_version = 1,
                 rnd_number_digits = 7,
                 context_lengths_min = 1000,
                 context_lengths_max = 126000,
                 context_lengths_num_intervals = 10,
                 document_depth_percent_min = 0,
                 document_depth_percent_max = 100,
                 document_depth_percent_intervals = 10,
                 document_depth_percent_interval_type = "linear",
                 save_results = False,
                 final_context_length_buffer = 200,
                 print_ongoing_status = True):
        needle="\nThe special magic {city} number is: {rnd_number}\n"
        self.needle = needle
        if not needle or not haystack_file or not retrieval_question:
            raise ValueError("Needle, haystack, and r
SYMBOL INDEX (210 symbols across 9 files)

FILE: lwm/data.py
  class DatasetFactory (line 16) | class DatasetFactory(object):
    method get_default_config (line 20) | def get_default_config(updates=None):
    method load_dataset (line 35) | def load_dataset(cls, config, tokenizer, **kwargs):
    method __init__ (line 51) | def __init__(self):
  class TextProcessor (line 55) | class TextProcessor(object):
    method get_default_config (line 58) | def get_default_config(updates=None):
    method __init__ (line 70) | def __init__(self, config, tokenizer):
    method __call__ (line 77) | def __call__(self, example, has_aux=False, add_bos_token=True, add_eos...
  class VisionTextProcessor (line 126) | class VisionTextProcessor(object):
    method get_default_config (line 128) | def get_default_config(updates=None):
    method __init__ (line 144) | def __init__(self, config, tokenizer):
    method __call__ (line 153) | def __call__(self, example, has_aux=False, add_bos_token=True, add_eos...
  class HuggingfaceDataset (line 242) | class HuggingfaceDataset(object):
    method get_default_config (line 248) | def get_default_config(updates=None):
    method __init__ (line 262) | def __init__(self, config, tokenizer, text_processor):
    method __iter__ (line 272) | def __iter__(self):
    method get_state_dict (line 305) | def get_state_dict(self):
    method load_state_dict (line 308) | def load_state_dict(self, state_dict):
    method seq_length (line 313) | def seq_length(self):
    method tokenizer (line 317) | def tokenizer(self):
    method text_processor (line 321) | def text_processor(self):
    method dataset (line 325) | def dataset(self):
    method vocab_size (line 329) | def vocab_size(self):
  class JsonDataset (line 333) | class JsonDataset(object):
    method get_default_config (line 339) | def get_default_config(updates=None):
    method __init__ (line 360) | def __init__(self, config, tokenizer, text_processor, node_info):
    method parse_json (line 370) | def parse_json(self, line):
    method json_iterator (line 380) | def json_iterator(self):
    method batched (line 398) | def batched(self, iterator, batch_size):
    method parallel_example_iterator (line 408) | def parallel_example_iterator(self):
    method __iter__ (line 434) | def __iter__(self):
    method _make_callback (line 510) | def _make_callback(self, v):
    method get_state_dict (line 513) | def get_state_dict(self):
    method load_state_dict (line 521) | def load_state_dict(self, state_dict):
    method seq_length (line 529) | def seq_length(self):
    method tokenizer (line 533) | def tokenizer(self):
    method text_processor (line 537) | def text_processor(self):
    method vocab_size (line 541) | def vocab_size(self):
  class JsonVisionDataset (line 545) | class JsonVisionDataset(object):
    method get_default_config (line 547) | def get_default_config(updates=None):
    method __init__ (line 568) | def __init__(self, config, tokenizer, text_processor, node_info):
    method parse_json (line 578) | def parse_json(self, line):
    method json_iterator (line 588) | def json_iterator(self):
    method batched (line 606) | def batched(self, iterator, batch_size):
    method parallel_example_iterator (line 616) | def parallel_example_iterator(self):
    method __iter__ (line 642) | def __iter__(self):
    method _iter_pad (line 651) | def _iter_pad(self):
    method _iter_no_pad (line 736) | def _iter_no_pad(self):
    method _make_callback (line 810) | def _make_callback(self, v):
    method get_state_dict (line 813) | def get_state_dict(self):
    method load_state_dict (line 821) | def load_state_dict(self, state_dict):
    method seq_length (line 829) | def seq_length(self):
    method tokenizer (line 833) | def tokenizer(self):
    method text_processor (line 837) | def text_processor(self):
    method vocab_size (line 841) | def vocab_size(self):

FILE: lwm/llama.py
  class LLaMAConfig (line 133) | class LLaMAConfig(PretrainedConfig):
    method __init__ (line 136) | def __init__(
    method get_default_config (line 193) | def get_default_config(cls, updates=None):
    method get_jax_mesh (line 202) | def get_jax_mesh(axis_dims):
    method get_ranks_and_size (line 206) | def get_ranks_and_size(mesh):
    method get_partition_rules (line 222) | def get_partition_rules(scan_layers=False, scan_axis=0):
    method get_weight_decay_exclusions (line 286) | def get_weight_decay_exclusions():
    method get_frozen_param_exclusions (line 290) | def get_frozen_param_exclusions(freeze_base):
    method rng_keys (line 297) | def rng_keys():
    method load_config (line 301) | def load_config(cls, path):
  class RMSNorm (line 320) | class RMSNorm(nn.Module):
    method setup (line 326) | def setup(self) -> None:
    method _norm (line 334) | def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
    method __call__ (line 337) | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
  function precompute_freqs_cis (line 344) | def precompute_freqs_cis(dim: int, max_position_embedding: int, theta: f...
  function apply_rotary_emb (line 353) | def apply_rotary_emb(
  class FlaxLLaMAAttention (line 378) | class FlaxLLaMAAttention(nn.Module):
    method setup (line 384) | def setup(self):
    method _split_heads (line 434) | def _split_heads(self, hidden_states):
    method _merge_heads (line 437) | def _merge_heads(self, hidden_states):
    method _concatenate_to_cache (line 441) | def _concatenate_to_cache(self, key, value, query, attention_mask):
    method __call__ (line 494) | def __call__(
  class FlaxLLaMAMLP (line 623) | class FlaxLLaMAMLP(nn.Module):
    method setup (line 629) | def setup(self) -> None:
    method __call__ (line 658) | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp....
  class FlaxLLaMABlock (line 664) | class FlaxLLaMABlock(nn.Module):
    method setup (line 670) | def setup(self) -> None:
    method __call__ (line 704) | def __call__(
  class FlaxLLaMAPreTrainedModel (line 747) | class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
    method __init__ (line 757) | def __init__(
    method init_weights (line 769) | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, pa...
    method init_cache (line 806) | def init_cache(self, batch_size, max_length):
    method __call__ (line 827) | def __call__(
  class FlaxLLaMABlockCollection (line 898) | class FlaxLLaMABlockCollection(nn.Module):
    method __call__ (line 905) | def __call__(
  class FlaxLLaMAModule (line 982) | class FlaxLLaMAModule(nn.Module):
    method setup (line 988) | def setup(self):
    method __call__ (line 1002) | def __call__(
  class FlaxLLaMAForCausalLMModule (line 1049) | class FlaxLLaMAForCausalLMModule(nn.Module):
    method setup (line 1055) | def setup(self):
    method __call__ (line 1066) | def __call__(
  class FlaxLLaMAForCausalLM (line 1110) | class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
    method prepare_inputs_for_generation (line 1113) | def prepare_inputs_for_generation(
    method update_inputs_for_generation (line 1134) | def update_inputs_for_generation(self, model_outputs, model_kwargs):

FILE: lwm/train.py
  function main (line 59) | def main(argv):

FILE: lwm/vision_chat.py
  class Sampler (line 40) | class Sampler:
    method __init__ (line 41) | def __init__(self):
    method block_size (line 52) | def block_size(self):
    method data_dim (line 56) | def data_dim(self):
    method _process_frame (line 59) | def _process_frame(self, image, size):
    method _read_process_vision (line 76) | def _read_process_vision(self, path, max_n_frames):
    method construct_input (line 110) | def construct_input(self, prompts, max_n_frames):
    method _load_model (line 148) | def _load_model(self):
    method _forward_generate (line 197) | def _forward_generate(self):
    method __call__ (line 222) | def __call__(self, prompts, max_n_frames):
  function main (line 236) | def main(argv):

FILE: lwm/vision_generation.py
  function main (line 44) | def main(argv):

FILE: lwm/vision_llama.py
  class VideoLLaMAConfig (line 27) | class VideoLLaMAConfig(LLaMAConfig):
    method __init__ (line 30) | def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False...
    method get_partition_rules (line 37) | def get_partition_rules(scan_layers=False, scan_axis=0):
    method load_config (line 107) | def load_config(cls, path):
  class FlaxVideoLLaMAPreTrainedModel (line 121) | class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel):
    method __init__ (line 131) | def __init__(
    method init_cache (line 143) | def init_cache(self, batch_size, max_length):
    method init_weights (line 156) | def init_weights(self, rng, input_shape, params=None):
    method __call__ (line 179) | def __call__(
  class FlaxVideoLLaMAModule (line 255) | class FlaxVideoLLaMAModule(nn.Module):
    method setup (line 261) | def setup(self):
    method __call__ (line 283) | def __call__(
  class FlaxVideoLLaMAForCausalLMModule (line 346) | class FlaxVideoLLaMAForCausalLMModule(nn.Module):
    method setup (line 352) | def setup(self):
    method __call__ (line 371) | def __call__(
  class FlaxVideoLLaMAForCausalLM (line 444) | class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel):
    method prepare_inputs_for_generation (line 447) | def prepare_inputs_for_generation(
    method update_inputs_for_generation (line 468) | def update_inputs_for_generation(self, model_outputs, model_kwargs):
    method _sample_vision (line 476) | def _sample_vision(
    method generate_vision (line 583) | def generate_vision(

FILE: lwm/vqgan.py
  class VQGAN (line 14) | class VQGAN:
    method __init__ (line 15) | def __init__(self, vqgan_checkpoint, replicate=False):
    method _wrap_fn (line 26) | def _wrap_fn(self, fn):
    method _encode (line 33) | def _encode(self):
    method _decode (line 43) | def _decode(self):
    method encode (line 52) | def encode(self, pixel_values):
    method decode (line 55) | def decode(self, encoding):
  class VQGANConfig (line 59) | class VQGANConfig(PretrainedConfig):
    method __init__ (line 62) | def __init__(
    method get_default_config (line 93) | def get_default_config(cls, updates=None):
    method load_config (line 101) | def load_config(cls, path):
  class VQGANModel (line 105) | class VQGANModel(nn.Module):
    method setup (line 108) | def setup(self):
    method encode (line 117) | def encode(self, pixel_values):
    method decode (line 130) | def decode(self, encoding, is_codebook_indices=True):
    method __call__ (line 143) | def __call__(self, pixel_values):
  class Encoder (line 149) | class Encoder(nn.Module):
    method __call__ (line 153) | def __call__(self, pixel_values):
  class Decoder (line 167) | class Decoder(nn.Module):
    method __call__ (line 171) | def __call__(self, hidden_states):
  class VectorQuantizer (line 187) | class VectorQuantizer(nn.Module):
    method __call__ (line 192) | def __call__(self, z, encoding_indices=None):
  class DownsamplingBlock (line 224) | class DownsamplingBlock(nn.Module):
    method __call__ (line 229) | def __call__(self, hidden_states):
  class ResnetBlock (line 242) | class ResnetBlock(nn.Module):
    method __call__ (line 248) | def __call__(self, hidden_states):
  class AttnBlock (line 266) | class AttnBlock(nn.Module):
    method __call__ (line 268) | def __call__(self, hidden_states):
  class Downsample (line 286) | class Downsample(nn.Module):
    method __call__ (line 290) | def __call__(self, hidden_states):
  class Upsample (line 306) | class Upsample(nn.Module):
    method __call__ (line 310) | def __call__(self, hidden_states):
  class UpsamplingBlock (line 322) | class UpsamplingBlock(nn.Module):
    method __call__ (line 327) | def __call__(self, hidden_states):
  class MidBlock (line 340) | class MidBlock(nn.Module):
    method __call__ (line 346) | def __call__(self, hidden_states):

FILE: scripts/eval_needle.py
  class LLMNeedleHaystackTester (line 47) | class LLMNeedleHaystackTester:
    method __init__ (line 64) | def __init__(self,
    method generate_random_number (line 109) | def generate_random_number(self, num_digits):
    method logistic (line 114) | def logistic(self, x, L=100, x0=50, k=.1):
    method read_context_files (line 121) | def read_context_files(self, n):
    method encode_and_trim (line 135) | def encode_and_trim(self, context, context_length):
    method create_contexts (line 141) | def create_contexts(self, needle_rnd_number, insert_needle, random_cit...
    method insert_needle (line 162) | def insert_needle(self, needle, context, depth_percent, context_length):
    method generate_context (line 199) | def generate_context(self, needle, trim_context, context_length, depth...
    method compute_max_input_length (line 203) | def compute_max_input_length(self, context_length, buffer=1024):
    method run_test (line 209) | def run_test(self):
    method print_start_test_summary (line 295) | def print_start_test_summary(self):
    method start_test (line 303) | def start_test(self):
  class Sampler (line 310) | class Sampler:
    method __init__ (line 311) | def __init__(self):
    method block_size (line 319) | def block_size(self):
    method data_dim (line 324) | def data_dim(self):
    method _load_model (line 327) | def _load_model(self):
    method _forward_generate (line 375) | def _forward_generate(self):
    method __call__ (line 402) | def __call__(self, prompts, max_input_length):
  function main (line 427) | def main(argv):

FILE: scripts/eval_needle_multi.py
  class LLMNeedleHaystackTester (line 50) | class LLMNeedleHaystackTester:
    method __init__ (line 67) | def __init__(self,
    method generate_random_number (line 114) | def generate_random_number(self, num_digits):
    method logistic (line 119) | def logistic(self, x, L=100, x0=50, k=.1):
    method read_context_files (line 126) | def read_context_files(self, n):
    method encode_and_trim (line 137) | def encode_and_trim(self, context, context_length):
    method create_contexts (line 143) | def create_contexts(self, needles_info, random_cities_retrieve, contex...
    method insert_needle (line 166) | def insert_needle(self, needle, context, depth_percent, context_length):
    method generate_context (line 203) | def generate_context(self, needle, trim_context, context_length, depth...
    method compute_max_input_length (line 207) | def compute_max_input_length(self, context_length, buffer=1024):
    method run_test (line 214) | def run_test(self):
    method print_start_test_summary (line 304) | def print_start_test_summary(self):
    method start_test (line 312) | def start_test(self):
  class Sampler (line 319) | class Sampler:
    method __init__ (line 320) | def __init__(self):
    method block_size (line 328) | def block_size(self):
    method data_dim (line 333) | def data_dim(self):
    method _load_model (line 336) | def _load_model(self):
    method _forward_generate (line 384) | def _forward_generate(self):
    method __call__ (line 411) | def __call__(self, prompts, max_input_length):
  function main (line 436) | def main(argv):

