Repository: LargeWorldModel/LWM
Branch: main
Commit: f45d2b70bda2
Files: 26
Total size: 244.6 KB
Directory structure:
gitextract_q2zwva4c/
├── .gitignore
├── LICENSE
├── README.md
├── docs/
│ ├── data.md
│ └── sharding.md
├── gpu_requirements.txt
├── lwm/
│ ├── __init__.py
│ ├── data.py
│ ├── llama.py
│ ├── train.py
│ ├── vision_chat.py
│ ├── vision_generation.py
│ ├── vision_llama.py
│ └── vqgan.py
├── scripts/
│ ├── create_needle_data.py
│ ├── eval_needle.py
│ ├── eval_needle_multi.py
│ ├── run_eval_needle.sh
│ ├── run_eval_needle_multi.sh
│ ├── run_sample_image.sh
│ ├── run_sample_video.sh
│ ├── run_train_text.sh
│ ├── run_train_vision_text.sh
│ ├── run_vision_chat.sh
│ └── sample_pyt.py
└── tpu_requirements.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# local
jobs/
local/
.vscode/
data/
*.model
*.npy
*.jsonl
*.pkl
*.json
__pycache__/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Large World Model (LWM)
[[Project]](https://largeworldmodel.github.io/)
[[Paper]](https://arxiv.org/abs/2402.08268)
[[Models]](https://huggingface.co/LargeWorldModel)
**Large World Model (LWM)** is a general-purpose large-context multimodal autoregressive model. It is trained on a large dataset of diverse long videos and books using RingAttention, and can perform language, image, and video understanding and generation.
## Approach
Current language models fall short in understanding aspects of the world not easily described in words, and struggle with complex, long-form tasks. Video sequences offer valuable temporal information absent in language and static images, making them attractive for joint modeling with language. Such models could develop an understanding of both human textual knowledge and the physical world, enabling broader AI capabilities for assisting humans. However, learning from millions of tokens of video and language sequences poses challenges due to memory constraints, computational complexity, and limited datasets. To address these challenges, we curate a large dataset of diverse videos and books, utilize the RingAttention technique to scalably train on long sequences, and gradually increase context size from 4K to 1M tokens. This work makes the following contributions:
- **Largest context size neural network:** We train one of the largest context size transformers on long video and language sequences, setting new benchmarks in difficult retrieval tasks and long video understanding.
- **Solutions for overcoming vision-language training challenges:** masked sequence packing for mixing different sequence lengths, loss weighting to balance language and vision, and a model-generated QA dataset for long-sequence chat.
- **A highly optimized implementation:** RingAttention, masked sequence packing, and other key features for training on multimodal sequences millions of tokens long.
- **Fully open-sourced models:** a family of 7B-parameter models capable of processing long text documents (LWM-Text, LWM-Text-Chat) and videos (LWM, LWM-Chat) of over 1M tokens.
This work paves the way for training on massive datasets of long video and language to develop understanding of both human knowledge and the multimodal world, and broader capabilities.
## LWM Capabilities
LWM can retrieve facts across a 1M-token context with high accuracy.
LWM can answer questions about YouTube videos over an hour long.
LWM can chat with images.
LWM can generate videos and images from text.
## Setup
This codebase is supported on Ubuntu and has not been tested on Windows or macOS. We recommend using TPUs for training and inference, although it is also possible to use GPUs. On TPU, the code is highly optimized with Jax's Pallas and can achieve high MFUs with RingAttention at very large context sizes. On GPU, the code is based on XLA and is not as optimized as it is for TPU.
Install the requirements with:
```
conda create -n lwm python=3.10
conda activate lwm
pip install -r gpu_requirements.txt
```
or set up TPU VM with:
```
sh tpu_requirements.sh
```
## Available models
There are language-only and video-language versions, offering context sizes of 32K, 128K, 256K, 512K, and 1M tokens. The vision-language models are available only in Jax, and the language-only models are available in both PyTorch and Jax. Below are the names of the available models and their corresponding context sizes and capabilities:
| Model Name | Context Size | Language or Vision-Language | Chat or Base | URL |
|--------------------|--------------|-----------------------------|--------------|----------------------------------------------------------------------------------------------------------------------------------------------|
| LWM-Text-Chat-128K | 128K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K-Jax)] |
| LWM-Text-Chat-256K | 256K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K-Jax)] |
| LWM-Text-Chat-512K | 512K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K-Jax)] |
| LWM-Text-Chat-1M | 1M | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M-Jax)] |
| LWM-Text-128K | 128K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-128K-Jax)] |
| LWM-Text-256K | 256K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-256K-Jax)] |
| LWM-Text-512K | 512K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-512K-Jax)] |
| LWM-Text-1M | 1M | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-1M-Jax)] |
| LWM-Chat-32K | 32K | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-32K-Jax)] |
| LWM-Chat-128K | 128K | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-128K-Jax)] |
| LWM-Chat-1M | 1M | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-1M-Jax)] |
## Code structure
Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block sizes in the blockwise computation of self-attention, and `scan_mlp_chunk_size` to control the block size in the blockwise computation of the feedforward network. Set `scan_attention` and `scan_mlp` to `True` or `False` to enable or disable blockwise computation in the self-attention and feedforward network, respectively.
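As an illustration of the general idea behind these chunk sizes (not the repository's JAX implementation), the pure-Python sketch below computes softmax attention for a single query while visiting keys and values in chunks of `key_chunk_size`. It keeps a running max and a running sum (an online softmax), so the full row of scores is never held at once. The sketch assumes one head, no causal mask, and no query chunking. `blockwise_attention`, `naive_attention`, and `key_chunk_size` are hypothetical names; the last mirrors `scan_key_chunk_size`.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference softmax attention for one query; materializes every score."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def blockwise_attention(q: Vector, keys: List[Vector], values: List[Vector],
                        key_chunk_size: int) -> Vector:
    """Same result as naive_attention, visiting keys/values one chunk at a time."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator, relative to running_max
    acc = [0.0] * dim         # unnormalized weighted sum of values
    for start in range(0, len(keys), key_chunk_size):
        k_chunk = keys[start:start + key_chunk_size]
        v_chunk = values[start:start + key_chunk_size]
        scores = [dot(q, k) * scale for k in k_chunk]
        new_max = max(running_max, max(scores))
        # Rescale earlier statistics so everything is relative to the new max.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_chunk):
            w = math.exp(s - new_max)
            running_sum += w
            for d in range(dim):
                acc[d] += w * v[d]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.1, -0.2, 0.3]
keys = [[0.5, 0.1, -0.3], [0.2, 0.2, 0.2], [-0.4, 0.9, 0.0], [0.3, -0.1, 0.7], [0.0, 0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
out = blockwise_attention(q, keys, values, key_chunk_size=2)
ref = naive_attention(q, keys, values)
assert all(abs(a - b) < 1e-9 for a, b in zip(out, ref))
```

The blockwise feedforward controlled by `scan_mlp_chunk_size` is simpler. Positions are independent in the MLP, so splitting the sequence into chunks and applying the MLP to each chunk gives the same result without any rescaling.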
You can use `mesh_dim='dp,fsdp,tp,sp'` to control the degree of parallelism and RingAttention. It is a string of 4 comma-separated integers giving the degree of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism, in that order.
For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. Likewise, `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.
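To check a `mesh_dim` value before launching a job, you could use a small helper like the hypothetical `resolve_mesh_dim` below. It maps the four integers to the `dp`, `fsdp`, `tp`, and `sp` axes and fills in a single `-1` from the device count, as described in [sharding.md](docs/sharding.md). It then verifies that the product of the axes equals the number of devices. This is a sketch, not the parser used by `lwm`. The leading `!` in the example script's `'!1,1,-1,1'` is not explained in this README, so the sketch simply strips it.

```python
import math
from typing import Dict

AXIS_NAMES = ("dp", "fsdp", "tp", "sp")


def resolve_mesh_dim(mesh_dim: str, num_devices: int) -> Dict[str, int]:
    """Map a 'dp,fsdp,tp,sp' string to named axis sizes, inferring at most one -1."""
    # Assumption: the leading '!' seen in the example script is stripped and ignored.
    dims = [int(x) for x in mesh_dim.lstrip("!").split(",")]
    if len(dims) != len(AXIS_NAMES):
        raise ValueError(f"expected 4 comma-separated integers, got {mesh_dim!r}")
    if any(d == 0 or d < -1 for d in dims) or dims.count(-1) > 1:
        raise ValueError("axes must be positive, with at most one -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split by {known}")
        dims[dims.index(-1)] = num_devices // known
    # The total number of devices must equal the product of the mesh dimensions.
    if math.prod(dims) != num_devices:
        raise ValueError(f"mesh {dims} needs {math.prod(dims)} devices, have {num_devices}")
    return dict(zip(AXIS_NAMES, dims))


print(resolve_mesh_dim("1,64,4,1", 256))  # {'dp': 1, 'fsdp': 64, 'tp': 4, 'sp': 1}
print(resolve_mesh_dim("1,1,4,64", 256))  # {'dp': 1, 'fsdp': 1, 'tp': 4, 'sp': 64}
print(resolve_mesh_dim("!1,1,-1,1", 8))   # {'dp': 1, 'fsdp': 1, 'tp': 8, 'sp': 1}
```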
## Running Jax Models
In this section, we provide instructions on how to run each of the provided scripts. For each script, you may need to fill in your own paths and values for the variables defined at the beginning of the script.
Run each script with `bash <script_name>.sh`:
- Language model training: `bash scripts/run_train_text.sh`
- Vision-Language model training: `bash scripts/run_train_vision_text.sh`
- Single Needle Evals (Language Model): `bash scripts/run_eval_needle.sh`
- Multi Needle Evals (Language Model): `bash scripts/run_eval_needle_multi.sh`
- Sampling images (Vision-Language Model): `bash scripts/run_sample_image.sh`
- Sampling videos (Vision-Language Model): `bash scripts/run_sample_video.sh`
- Image / Video understanding (Vision-Language Model): `bash scripts/run_vision_chat.sh`
By default the `mesh_dim` argument puts all devices on `tp` (tensor parallelism). For longer sequences, you may want to include `sp`, which is the last dimension in the `mesh_dim`.
When running needle evals, you may need to adjust the `theta` and `max_sequence_length` arguments in the scripts depending on the model. The table below lists the correct values for each model.
| Argument | LWM-Text-128K / LWM-Text-Chat-128K | LWM-Text-256K / LWM-Text-Chat-256K | LWM-Text-512K / LWM-Text-Chat-512K | LWM-Text-1M / LWM-Text-Chat-1M |
|---------------------|:-----------------------------------:|:-----------------------------------:|:----------------------------------:|:------------------------------:|
| theta | 10000000 | 10000000 | 25000000 | 50000000 |
| max_sequence_length | 131072 | 262144 | 524288 | 1048576 |
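Instead of editing these numbers by hand, you can keep the values from the table in a dictionary and turn them into the `--update_llama_config` argument. The example script below passes that argument in `dict(...)` form. `NEEDLE_EVAL_CONFIG` and `update_llama_config_arg` are hypothetical conveniences, not part of the repository. Any extra keys you pass, such as `scan_attention`, must be options the model config actually accepts. Each `max_sequence_length` is the context size in binary units; for example, 128K = 128 × 1024 = 131072.

```python
# Values copied from the table above (needle evals, LWM-Text models).
NEEDLE_EVAL_CONFIG = {
    "128K": {"theta": 10_000_000, "max_sequence_length": 131072},
    "256K": {"theta": 10_000_000, "max_sequence_length": 262144},
    "512K": {"theta": 25_000_000, "max_sequence_length": 524288},
    "1M": {"theta": 50_000_000, "max_sequence_length": 1048576},
}


def update_llama_config_arg(context: str, **extra) -> str:
    """Build a --update_llama_config value in the dict(...) form used by the scripts."""
    settings = {**NEEDLE_EVAL_CONFIG[context], **extra}
    # repr() keeps strings quoted ('vision') and booleans as True/False.
    body = ",".join(f"{key}={value!r}" for key, value in settings.items())
    return f'--update_llama_config="dict({body})"'


assert NEEDLE_EVAL_CONFIG["128K"]["max_sequence_length"] == 2**17
print(update_llama_config_arg("1M", scan_attention=True))
# --update_llama_config="dict(theta=50000000,max_sequence_length=1048576,scan_attention=True)"
```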
An example of filling out a script (`run_sample_video.sh`) is as follows:
```bash
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export vqgan_checkpoint="/path/to/ckpt/folder/vqgan"
export lwm_checkpoint="params::/path/to/ckpt/folder/params"
python3 -u -m lwm.vision_generation \
--prompt='Fireworks over the city' \
--output_file='fireworks.mp4' \
--temperature_image=1.0 \
--temperature_video=1.0 \
--top_k_image=8192 \
--top_k_video=1000 \
--cfg_scale_image=5.0 \
--cfg_scale_video=1.0 \
--vqgan_checkpoint="$vqgan_checkpoint" \
--n_frames=8 \
--mesh_dim='!1,1,-1,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
--load_checkpoint="$lwm_checkpoint" \
--tokenizer="$llama_tokenizer_path"
read
```
## Needle Haystack Data
Run `python scripts/create_needle_data.py`
## Running PyTorch Models
Only text and text chat models are currently supported for PyTorch inference. PyTorch models can be loaded as Hugging Face `LlamaForCausalLM` models. Run `python scripts/sample_pyt.py` to sample. You may need to separately install `torch`.
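Because the text models load as `LlamaForCausalLM`, a minimal sampling sketch with Hugging Face `transformers` can look like the following. It is not the contents of `scripts/sample_pyt.py`. The function name `generate_text`, greedy decoding, and the 64-token default are illustrative choices. No chat prompt template is applied, because the expected format for the chat models is not documented here. Loading a 7B model in full precision takes roughly 28 GB of memory (7B parameters × 4 bytes).

```python
def generate_text(model_name: str, prompt: str, max_new_tokens: int = 64) -> str:
    """Greedy-decode a continuation of `prompt` with an LWM-Text PyTorch checkpoint."""
    # Imported lazily so this module can be imported without transformers installed.
    from transformers import AutoTokenizer, LlamaForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LlamaForCausalLM.from_pretrained(model_name)
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt")
    # generate() runs without gradient tracking; do_sample=False selects greedy decoding.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the prompt tokens and decode only the newly generated ones.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

For example, `generate_text("LargeWorldModel/LWM-Text-Chat-128K", "The capital of France is")` downloads the checkpoint from the Hugging Face Hub on first use and returns the decoded continuation.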
## Documentation
For more details on the codebase, please refer to [data.md](docs/data.md), which covers data processing, and [sharding.md](docs/sharding.md), which covers sharding and parallelism.
## If you have issues
This is based on the [codebase](https://github.com/haoliuhl/ringattention) of RingAttention, with the necessary features for vision-language training. The training and inference have been tested on both TPUv3 and TPUv4.
If you encounter bugs, please open a GitHub issue!
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@article{liu2023world,
title={World Model on Million-Length Video and Language with RingAttention},
author={Liu, Hao and Yan, Wilson and Zaharia, Matei and Abbeel, Pieter},
journal={arXiv preprint},
year={2024},
}
@article{liu2023ring,
title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
journal={International Conference on Learning Representations},
year={2024}
}
@article{liu2023blockwise,
title={Blockwise Parallel Transformer for Large Context Models},
author={Liu, Hao and Abbeel, Pieter},
journal={Advances in neural information processing systems},
year={2023}
}
```
## License
LWM's code is released under the Apache 2.0 License. See [LICENSE](https://github.com/LargeWorldModel/lwm/blob/main/LICENSE) for further details. The models are released under the Llama-2 license.
================================================
FILE: docs/data.md
================================================
# Data
We support two types of datasets: Huggingface dataset and JSON dataset. The dataset modules are implemented in the [data.py](/lwm/data.py) file.
Configuration requires the dataset type, the text processor, and dataset-specific configurations.
The following is an example of using the Huggingface dataset to train a model:
```bash
python -m lwm.train \
--train_dataset.text_processor.fields='text' \
--train_dataset.type='huggingface' \
--train_dataset.huggingface_dataset.path='openwebtext'
```
In this example, we select the Huggingface dataset by specifying the `type` of
`train_dataset` to be `huggingface`. We then specify the path to the dataset,
which is `openwebtext` in this case. The examples loaded from the dataset are processed
by a TextProcessor, which is configured by the `text_processor` field.
The following options are supported for the dataset module:
* `type`: The type of the dataset. Supported values are `huggingface` and `json`.
* `text_processor`: The configuration of the TextProcessor used to process the
loaded examples.
* `huggingface_dataset`: The configuration of the Huggingface dataset.
* `json_dataset`: The configuration of the JSON dataset.
For the Huggingface dataset, examples are loaded from a Huggingface dataset with the following options:
* `path`: The path to the dataset. Same as the `path` argument in
`datasets.load_dataset`.
* `name`: Name of the dataset within the path. Same as the `name` argument in
`datasets.load_dataset`.
* `split`: The split of the dataset. Same as the `split` argument in
`datasets.load_dataset`.
* `streaming`: Whether to stream the dataset. Same as the `streaming` argument
in `datasets.load_dataset`.
* `seq_length`: The length of the tokenized sequence.
* `batch_size`: Batch size of tokenized examples.
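The `seq_length` and `batch_size` options describe the shape of the batches fed to training. One common way to produce such batches is to concatenate tokenized examples into a stream and cut it into `batch_size` rows of `seq_length` tokens. This is shown here only as an assumption about the behavior: `lwm/data.py` may differ, for example in how it handles leftover tokens or loss masks. `pack_batches` is a hypothetical name.

```python
from typing import Iterable, Iterator, List


def pack_batches(
    token_streams: Iterable[List[int]], seq_length: int, batch_size: int
) -> Iterator[List[List[int]]]:
    """Concatenate tokenized examples and yield batches of shape (batch_size, seq_length).

    Hypothetical sketch: leftover tokens that do not fill a whole batch are dropped.
    """
    buffer: List[int] = []
    chunk = seq_length * batch_size  # tokens needed for one full batch
    for tokens in token_streams:
        buffer.extend(tokens)
        while len(buffer) >= chunk:
            flat, buffer = buffer[:chunk], buffer[chunk:]
            # Split the flat run of tokens into batch_size rows.
            yield [flat[i:i + seq_length] for i in range(0, chunk, seq_length)]


examples = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14]]
for batch in pack_batches(examples, seq_length=4, batch_size=2):
    print(batch)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```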
For the JSON dataset, examples are loaded from a text file in which each line is a
JSON-encoded dictionary. Here are the configurable options for the JSON dataset:
* `path`: Path to the text file. The file can be located on the local file system
or in a Google Cloud Storage bucket.
* `seq_length`: The length of the tokenized sequence.
* `batch_size`: Batch size of tokenized examples.
* `start_seek_loc`: The starting seek location in the file. This is useful when
you want to resume training from a particular location in the file.
* `index_at_start`: The counting index at the beginning. This is useful to
keep the index count when resuming from a particular location in the file.
Note that this is only for logging purposes and does not change which example
loading starts from. To start from a different example in the dataset,
use the `start_seek_loc` option.
* `tokenizer_processes`: The number of processes to use for tokenization.
Tokenization is done in parallel to speed up the loading process.
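A small reader for local JSON-lines files illustrates how `start_seek_loc` and `index_at_start` interact: the byte offset decides which line is read first, while the index is only a counter that is carried forward. This is a hypothetical sketch (`iter_jsonl` is not the repository's loader). It does not cover Google Cloud Storage paths or parallel tokenization.

```python
import json
import os
import tempfile
from typing import Iterator, Tuple


def iter_jsonl(
    path: str, start_seek_loc: int = 0, index_at_start: int = 0
) -> Iterator[Tuple[dict, int, int]]:
    """Yield (example, index, next_seek_loc) for each non-empty line of a JSONL file.

    Hypothetical sketch: start_seek_loc selects where reading begins (a byte
    offset); index_at_start only initializes the counter used for logging.
    """
    index = index_at_start
    # Binary mode so that seek()/tell() work with plain byte offsets.
    with open(path, "rb") as f:
        f.seek(start_seek_loc)
        while True:
            line = f.readline()
            if not line:  # end of file
                break
            if line.strip():
                yield json.loads(line), index, f.tell()
                index += 1


# Usage: read one example, then resume from where it ended.
rows = [{"text": "a"}, {"text": "b"}, {"text": "c"}]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as tmp:
    for row in rows:
        tmp.write(json.dumps(row) + "\n")

reader = iter_jsonl(tmp.name)
_, last_index, resume_loc = next(reader)
reader.close()  # closes the file held open by the generator

resumed = [(ex["text"], i) for ex, i, _ in iter_jsonl(tmp.name, resume_loc, last_index + 1)]
print(resumed)  # [('b', 1), ('c', 2)]
os.remove(tmp.name)
```

To resume, save `next_seek_loc` and the index from the last consumed example. Pass the offset back as `start_seek_loc` and the index plus one as `index_at_start`, and the position and the logging count stay aligned.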
A JSON dataset can be generated as follows:
```python
import json
from multiprocessing import Pool

from datasets import load_dataset

# Load OpenWebText and hold out a small validation split.
dataset = load_dataset("openwebtext")
split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
split_dataset["val"] = split_dataset.pop("test")  # rename the held-out split to "val"


def save_split(split):
    # Write one JSON object per line, keeping only the "text" field.
    with open(f"openwebtext_{split}.jsonl", "w") as f:
        for example in split_dataset[split]:
            json.dump({"text": example["text"]}, f)
            f.write("\n")


if __name__ == "__main__":
    # One worker per split.
    with Pool(2) as p:
        p.map(save_split, ["train", "val"])
```
This generates two files, `openwebtext_train.jsonl` and `openwebtext_val.jsonl`, which can be used as the dataset for training. Both files contain a single field, `text`, which is the text to be processed by the model.
For example, to train a model using the `openwebtext_train.jsonl` file, you can use the following command:
```bash
python -m lwm.train \
--train_dataset.text_processor.fields='text' \
--train_dataset.type='json' \
--train_dataset.json_dataset.path='openwebtext_train.jsonl'
```
For vision-language training, we recommend using the JSON dataset, as it allows you to pre-tokenize vision (images and videos), and load the tokenized vision along with the text.
Each loaded example is a dictionary, which is then processed by a TextProcessor.
## Text Processor
We use the `TextProcessor` class to process the loaded examples from a dataset. This allows us to flexibly process various formats.
Each input example is a dictionary of multiple text fields. The TextProcessor will
process text fields according to its configurations, and return the final tokens.
Here are the configurable options for TextProcessor:
* `fields`: A comma separated list of text fields to process.
* `fields_from_example`: Whether to use the keys of the input example as the
text fields to process. If this option is set, the `fields` argument will
be ignored.
* `subfield_separator`: The text separator to use when concatenating subfields
of a text field.
* `add_eos_token`: Whether to add an EOS token to the end of the text.
* `prepend_text`: The text to prepend to the beginning.
The most important configuration for TextProcessor is the `fields` argument. It
is a comma separated list of text fields to process. Each field consists of one
or more subfields, which are separated by a `+`. Each subfield represents a key
used to extract the text from the input example dictionary. The TextProcessor
joins the extracted subfields with the `subfield_separator` at the text
level and then tokenizes the joined text. Finally, the TextProcessor concatenates
the tokenized text fields at the token level and adds the EOS token if specified.
Other than the keys in the input example, you can also use the following special
keys to indicate a special token for a text field:
* `<|bos|>`: Beginning of sentence token.
* `<|eos|>`: End of sentence token.
For each text field, you can enclose the subfields in `[]` to specify that
the loss should not be computed for this field. Doing so sets the loss
mask to 0 for this field. This is useful when you want to use the text field
as a prompt for the model.
To give a concrete example, if the input example looks like this:
```python
{
'vision': 'VQ tokens of a picture of a cat',
'question': 'what is the color of the cat',
'answer': 'the color of the cat is yellow',
}
```
To use the `vision` and `question` as the input text, and `answer` as the target,
we can specify the following configuration for the `fields` argument:
```
[vision+question],answer
```
The `vision+question` indicates that the `vision` and `question` should be joined
together with the `subfield_separator`, which is a space by default. The `[]`
indicates that the loss should not be computed for this field. The `answer` field
is then concatenated at the token level, where the loss will be computed.
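The `fields` rules above can be sketched in a few lines of Python. This is an illustrative re-implementation under stated assumptions, not the repository's `TextProcessor`. The tokenizer is a toy whitespace tokenizer, the special-token ids 1 and 2 are made up, and the EOS token added by `add_eos_token` is assumed to carry a loss mask of 1. `prepend_text` and `fields_from_example` are omitted. `process_example` and `toy_tokenize` are hypothetical names.

```python
from typing import Callable, Dict, List, Tuple

SPECIAL_TOKENS = {"<|bos|>": 1, "<|eos|>": 2}  # hypothetical token ids
vocab: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical whitespace tokenizer that assigns ids on first sight (from 3)."""
    return [vocab.setdefault(word, len(vocab) + 3) for word in text.split()]


def process_example(
    example: Dict[str, str],
    fields: str,
    tokenize: Callable[[str], List[int]],
    subfield_separator: str = " ",
    add_eos_token: bool = True,
) -> Tuple[List[int], List[float]]:
    """Return (tokens, loss_masks) following the `fields` rules described above."""
    tokens: List[int] = []
    loss_masks: List[float] = []
    for field in fields.split(","):
        # A field wrapped in [] is a prompt: its tokens get loss mask 0.
        masked = field.startswith("[") and field.endswith("]")
        if masked:
            field = field[1:-1]
        mask = 0.0 if masked else 1.0
        if field in SPECIAL_TOKENS:
            field_tokens = [SPECIAL_TOKENS[field]]
        else:
            # Join subfields at the text level, then tokenize the joined text.
            text = subfield_separator.join(example[key] for key in field.split("+"))
            field_tokens = tokenize(text)
        # Concatenate fields at the token level.
        tokens.extend(field_tokens)
        loss_masks.extend([mask] * len(field_tokens))
    if add_eos_token:
        tokens.append(SPECIAL_TOKENS["<|eos|>"])
        loss_masks.append(1.0)  # assumption: EOS is trained on
    return tokens, loss_masks


example = {
    "vision": "VQ tokens of a picture of a cat",
    "question": "what is the color of the cat",
    "answer": "the color of the cat is yellow",
}
tokens, masks = process_example(example, "[vision+question],answer", toy_tokenize)
print(masks.count(0.0), masks.count(1.0))  # 15 8
```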
================================================
FILE: docs/sharding.md
================================================
# Sharding
Sharding is a technique to partition the computation and the model across multiple accelerators.
This codebase supports flexible model and data parallelism for training and serving.
The sharding can be specified using the `mesh_dim` command line argument. `mesh_dim` is a
comma-separated list of integers giving the dimension of each parallelism mesh axis. One of the
axis dimensions can be `-1`, which means that the axis dimension will be inferred from the
total number of accelerators.
The first axis of the mesh is used for data parallelism (`dp`), the second for fully sharded
data parallelism (`fsdp`), the third for tensor parallelism (`tp`), and the last for
sequence parallelism (`sp`), which is required for RingAttention.
For example, `mesh_dim='1,64,4,1'` means 1-way data parallelism, 64-way fully sharded data parallelism, 4-way tensor parallelism, and 1-way sequence parallelism, whereas `mesh_dim='1,1,4,64'` means 1-way data parallelism, 1-way fully sharded data parallelism, 4-way tensor parallelism, and 64-way sequence parallelism for RingAttention.
The total number of accelerators must equal the product of the mesh dimensions. For example, `mesh_dim='1,64,4,1'` and `mesh_dim='1,1,4,64'` both require 256 accelerators (1 * 64 * 4 * 1 = 1 * 1 * 4 * 64 = 256).
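The `-1` inference and the product rule can be illustrated with a short pure-Python sketch. In this codebase the actual mesh is built by `get_jax_mesh` from `tux`, using the axis names `('dp', 'fsdp', 'tp', 'sp')`. The helper `resolve_mesh_dim` below is hypothetical and only demonstrates the arithmetic; it does not reproduce that function.

```python
import math

# Axis order used by this codebase: dp, fsdp, tp, sp.
AXIS_NAMES = ('dp', 'fsdp', 'tp', 'sp')


def resolve_mesh_dim(mesh_dim: str, num_devices: int) -> dict:
    """Hypothetical helper: parse mesh_dim, infer a single -1 axis, check the product."""
    dims = [int(d) for d in mesh_dim.split(',')]
    if len(dims) != len(AXIS_NAMES):
        raise ValueError(f'expected {len(AXIS_NAMES)} axes, got {len(dims)}')
    if any(d < 1 and d != -1 for d in dims):
        raise ValueError('axis dimensions must be positive or -1')
    if dims.count(-1) > 1:
        raise ValueError('at most one axis dimension may be -1')
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError(f'{num_devices} devices cannot be split by {known}')
        dims[dims.index(-1)] = num_devices // known  # inferred axis dimension
    if math.prod(dims) != num_devices:
        raise ValueError(f'mesh {dims} needs {math.prod(dims)} devices, have {num_devices}')
    return dict(zip(AXIS_NAMES, dims))


print(resolve_mesh_dim('1,64,4,1', 256))  # {'dp': 1, 'fsdp': 64, 'tp': 4, 'sp': 1}
print(resolve_mesh_dim('1,-1,4,1', 256))  # fsdp is inferred as 256 // (1 * 4 * 1) = 64
```

Passing `'1,64,4,1'` with only 64 devices raises an error, because that mesh needs 256.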
In general, you want to use the largest possible `fsdp` dimension. For example, `mesh_dim='1,64,1,1'` is preferred over `mesh_dim='8,8,1,1'` because the former has a larger `fsdp` dimension, which allows computation and communication to overlap and thus gives better performance.
The batch size (number of sequences per batch) should be greater than or equal to `dp * fsdp`. If that batch size is too large for you, you can allocate more accelerators to `tp` and `sp` instead. For a fixed number of accelerators, this reduces `dp * fsdp` and lets you scale the model size and sequence length.
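The batch-size constraint can be written as a small check. `sequences_per_data_shard` is a hypothetical helper, not part of this codebase. The data loaders here shard the batch axis over `('dp', 'fsdp')`. The divisibility check is an added assumption, based on JAX generally requiring a sharded dimension to split evenly across its mesh axes.

```python
def sequences_per_data_shard(batch_size: int, dp: int, fsdp: int) -> int:
    """Hypothetical check: each of the dp * fsdp data shards needs at least one sequence."""
    num_data_shards = dp * fsdp
    if batch_size < num_data_shards:
        raise ValueError(f'batch_size={batch_size} < dp * fsdp = {num_data_shards}')
    # Assumption: the batch axis must divide evenly across the ('dp', 'fsdp') axes.
    if batch_size % num_data_shards != 0:
        raise ValueError(f'batch_size={batch_size} is not divisible by {num_data_shards}')
    return batch_size // num_data_shards


# Same 256 accelerators, two layouts:
print(sequences_per_data_shard(64, dp=1, fsdp=64))  # '1,64,4,1' needs batch_size >= 64
print(sequences_per_data_shard(1, dp=1, fsdp=1))    # '1,1,4,64' works with batch_size 1
```

Moving accelerators from `fsdp` to `sp` lowers the minimum batch size. Each sequence can then be longer, because it is spread across the `sp` axis.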
RingAttention requires sequence parallelism, which is controlled by `sp`. `sp=8` shards the sequence length 8 ways, and `sp=1` means no sequence sharding.
For models that use standard attention, you can set `sp=1` and use `dp`, `fsdp`, and `tp` to control the parallelism.
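The effect of `sp` on a single sequence can be pictured as splitting it into `sp` contiguous blocks, one per sequence-parallel shard. This is similar in spirit to the contiguous `v[:, rank*seq_chunk_size:(rank+1)*seq_chunk_size]` slicing that the JSON loaders in `lwm/data.py` apply per host. The helper `shard_sequence` is hypothetical and ignores how RingAttention then exchanges key/value blocks between shards.

```python
def shard_sequence(tokens: list, sp: int) -> list:
    """Split one sequence into sp contiguous chunks, one per sequence-parallel shard."""
    if len(tokens) % sp != 0:
        raise ValueError(f'sequence length {len(tokens)} is not divisible by sp={sp}')
    chunk = len(tokens) // sp
    return [tokens[rank * chunk:(rank + 1) * chunk] for rank in range(sp)]


seq = list(range(8))
print(shard_sequence(seq, 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(shard_sequence(seq, 1))  # [[0, 1, 2, 3, 4, 5, 6, 7]]: sp=1 means no sharding
```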
================================================
FILE: gpu_requirements.txt
================================================
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12]==0.4.29
flax==0.8.4
optax==0.2.2
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.0
transformers==4.40.0
ringattention @ git+https://github.com/haoliuhl/ringattention.git
datasets
einops
tqdm
ml_collections
wandb
gcsfs
requests
typing-extensions
sentencepiece
tux @ git+https://github.com/haoliuhl/tux.git
Pillow
ffmpeg-python
ipdb
imageio[ffmpeg]
opencv-python
decord
h5py
psutil
================================================
FILE: lwm/__init__.py
================================================
================================================
FILE: lwm/data.py
================================================
import time
import random
from functools import partial
import json
from multiprocessing import Pool
from tux import open_file
from ml_collections import ConfigDict
import numpy as np
import jax
from jax.experimental.multihost_utils import host_local_array_to_global_array
from jax.sharding import PartitionSpec as PS
from datasets import load_dataset
class DatasetFactory(object):
""" Datset builder class. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.type = 'huggingface'
config.text_processor = TextProcessor.get_default_config()
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
config.json_dataset = JsonDataset.get_default_config()
config.vision_text_processor = VisionTextProcessor.get_default_config()
config.json_vision_dataset = JsonVisionDataset.get_default_config()
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def load_dataset(cls, config, tokenizer, **kwargs):
config = cls.get_default_config(config)
if config.type == 'huggingface':
text_processor = TextProcessor(config.text_processor, tokenizer)
return HuggingfaceDataset(
config.huggingface_dataset, tokenizer, text_processor, **kwargs
)
elif config.type == 'json':
text_processor = TextProcessor(config.text_processor, tokenizer)
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
elif config.type == 'json_vision':
vision_text_processor = VisionTextProcessor(config.vision_text_processor, tokenizer)
return JsonVisionDataset(config.json_vision_dataset, tokenizer, vision_text_processor, **kwargs)
else:
raise ValueError(f'Unknown dataset type: {config.type}')
def __init__(self):
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
class TextProcessor(object):
""" Example processor that converts a dictionary of texts into tokens. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.fields_from_example = ''
config.fields = ''
config.subfield_separator = ' '
config.add_bos_token = True
config.add_eos_token = True
config.prepend_text = ''
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer):
self.config = self.get_default_config(config)
assert self.config.fields != '' or self.config.fields_from_example != '', (
'Either fields or fields_from_example must be specified.'
)
self.tokenizer = tokenizer
def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
if has_aux:
example, *aux = example
else:
aux = tuple()
token_buffer = []
loss_mask_buffer = []
if add_bos_token and self.config.add_bos_token:
token_buffer.append(self.tokenizer.bos_token_id)
loss_mask_buffer.append(0.0)
if self.config.fields_from_example != '':
fields = example[self.config.fields_from_example].split(',')
else:
fields = self.config.fields.split(',')
for i, field in enumerate(fields):
if field.startswith('[') and field.endswith(']'):
# No loss for this field.
field = field[1:-1]
mask = 0.0
else:
mask = 1.0
if field == '<|bos|>':
token_buffer.append(self.tokenizer.bos_token_id)
loss_mask_buffer.append(mask)
elif field == '<|eos|>':
token_buffer.append(self.tokenizer.eos_token_id)
loss_mask_buffer.append(mask)
else:
subfields = field.split('+')
text = self.config.subfield_separator.join(
[example[subfield] for subfield in subfields]
)
if i == 0:
text = self.config.prepend_text + text
tokens = self.tokenizer.encode(text, add_special_tokens=False)
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
if add_eos_token and self.config.add_eos_token:
token_buffer.append(self.tokenizer.eos_token_id)
loss_mask_buffer.append(1.0)
return token_buffer, loss_mask_buffer, *aux
class VisionTextProcessor(object):
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.fields_from_example = ''
config.subfield_separator = ' '
config.add_bos_token = True
config.add_eos_token = True
config.prepend_text = ''
config.fields_index = -1
config.eof_token = 8192 # denotes end of each frame for video generation
config.eov_token = 8193 # denotes end of vision generation
config.n_tokens_per_frame = 256 # 16 x 16 VQ codes
config.max_n_frames = -1
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer):
self.config = self.get_default_config(config)
assert self.config.fields_from_example != '', (
'fields_from_example must be specified.'
)
self.tokenizer = tokenizer
self.vision_start = tokenizer.encode('')
self.vision_end = tokenizer.encode('')
def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):
if has_aux:
example, *aux = example
else:
aux = tuple()
rand_state = random.Random(aux[-1] if aux else None) # makes augmentations deterministic by line number when aux is provided
token_buffer = []
loss_mask_buffer = []
vision_mask = []
fields = example[self.config.fields_from_example]
if isinstance(fields, (tuple, list)):
if self.config.fields_index >= 0:
fields = fields[self.config.fields_index]
else:
# seed based on line number
fields = rand_state.choice(fields)
fields = fields.split(',')
if add_bos_token and self.config.add_bos_token:
token_buffer.append(self.tokenizer.bos_token_id)
loss_mask_buffer.append(0.0)
vision_mask.append(False)
for i, field in enumerate(fields):
if field.startswith('[') and field.endswith(']'):
# No loss for this field.
field = field[1:-1]
mask = 0.0
else:
mask = 1.0
if field == '<|bos|>':
token_buffer.append(self.tokenizer.bos_token_id)
loss_mask_buffer.append(mask)
vision_mask.append(False)
elif field == '<|eos|>':
token_buffer.append(self.tokenizer.eos_token_id)
loss_mask_buffer.append(mask)
vision_mask.append(False)
elif 'vision' in field:
vision_tokens = example[field]
n_frames = int(len(vision_tokens) / self.config.n_tokens_per_frame)
if self.config.max_n_frames > 0 and n_frames > self.config.max_n_frames: # uniformly select
idxs = np.linspace(0, n_frames - 1, self.config.max_n_frames).astype(int)
new_vision_tokens = []
for idx in idxs:
new_vision_tokens.extend(vision_tokens[idx * self.config.n_tokens_per_frame:(idx + 1) * self.config.n_tokens_per_frame])
vision_tokens = new_vision_tokens
n_frames = self.config.max_n_frames
assert int(len(vision_tokens) / self.config.n_tokens_per_frame) == n_frames, (int(len(vision_tokens) / self.config.n_tokens_per_frame), n_frames)
assert n_frames > 0, len(vision_tokens)
tokens = list(self.vision_start)
for j in range(n_frames):
tokens.extend(vision_tokens[j*self.config.n_tokens_per_frame:(j+1)*self.config.n_tokens_per_frame])
if j == n_frames - 1: # last frame
tokens.append(self.config.eov_token)
else:
tokens.append(self.config.eof_token)
tokens.extend(self.vision_end)
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
vision_mask.extend([False] * len(self.vision_start))
vision_mask.extend([True] * (self.config.n_tokens_per_frame * n_frames + n_frames)) # include extra eof/eov token at the end of each frame
vision_mask.extend([False] * len(self.vision_end))
else:
subfields = field.split('+')
text = self.config.subfield_separator.join(
[example[subfield] for subfield in subfields]
)
if i == 0:
text = self.config.prepend_text + text
tokens = self.tokenizer.encode(text)
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
vision_mask.extend([False] * len(tokens))
if add_eos_token and self.config.add_eos_token:
token_buffer.append(self.tokenizer.eos_token_id)
loss_mask_buffer.append(1.0)
vision_mask.append(False)
assert len(token_buffer) == len(loss_mask_buffer) == len(vision_mask), (len(token_buffer), len(loss_mask_buffer), len(vision_mask))
keep = True
return token_buffer, loss_mask_buffer, vision_mask, keep, *aux
class HuggingfaceDataset(object):
""" Huggingface dataset, where the dataset is loaded using the huggingface
datasets.load_dataset() function.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = 'c4'
config.name = 'en'
config.split = 'train'
config.streaming = False
config.seq_length = 1024
config.batch_size = 8
config.always_start_with_bos = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor):
self.config = self.get_default_config(config)
name = self.config.name if self.config.name != '' else None
split = self.config.split if self.config.split != '' else None
self._tokenizer = tokenizer
self._text_processor = text_processor
self._dataset = load_dataset(
self.config.path, name, split=split, streaming=self.config.streaming
)
def __iter__(self):
chunk_size = self.config.batch_size * self.config.seq_length
total_tokens = 0
while True:
token_buffer = []
loss_mask_buffer = []
for index, example in enumerate(self._dataset):
tokens, loss_masks = self.text_processor(example)
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
while len(token_buffer) > chunk_size + 1:
total_tokens += chunk_size
metrics = {
'dataset_example_index': index,
'dataset_total_tokens': total_tokens,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
self.config.batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
self.config.batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
self.config.batch_size, -1
),
}
if self.config.always_start_with_bos:
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
yield batch, metrics
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
def get_state_dict(self):
return dict(config=self.config)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def dataset(self):
return self._dataset
@property
def vocab_size(self):
return len(self._tokenizer)
class JsonDataset(object):
""" JSON dataset, where each line of the data file contains a JSON
dictionary with text fields.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = ''
config.seq_length = 1024
config.batch_size = 8
config.always_start_with_bos = False
config.start_seek_loc = 0
config.example_index_at_start = 0
config.tokens_count_at_start = 0
config.tokenizer_processes = 1
config.tokenizer_parallel_chunk_size = 32
config.tokenizer_parallel_batch_size = 1024
config.throughput_average_window_size = 200
config.pad = False
config.use_data_sharded_loader = True
config.return_local_batch = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor, node_info):
self.config = self.get_default_config(config)
assert self.config.path != ''
self._tokenizer = tokenizer
self._text_processor = text_processor
self._node_info = node_info
self._index = self.config.example_index_at_start
self._file_loc = self.config.start_seek_loc
self._total_tokens = self.config.tokens_count_at_start
def parse_json(self, line):
if not line or line == '\n':
return None
try:
data = json.loads(line)
except json.decoder.JSONDecodeError:
print(f'Error parsing json line:\n{line}')
return None
return data
def json_iterator(self):
index, file_loc = self._index, self._file_loc
with open_file(self.config.path, 'r') as fin:
fin.seek(file_loc)
while True:
line = fin.readline()
file_loc = fin.tell()
if not line: # Reached EOF
index = 0
fin.seek(0)
continue
data = self.parse_json(line)
if data is not None and (not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']):
# JSON parsing succeeded
yield data, file_loc, index
index += 1
def batched(self, iterator, batch_size):
batch = []
for example in iterator:
batch.append(example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def parallel_example_iterator(self):
if self.config.tokenizer_processes == 1:
for example, loc, index in self.json_iterator():
self._file_loc = loc
self._index = index
yield self.text_processor((example, loc, index), has_aux=True)
else:
process_pool = Pool(self.config.tokenizer_processes)
batched_iterator = self.batched(
self.json_iterator(), self.config.tokenizer_parallel_batch_size
)
with process_pool as pool:
map_fn = partial(self.text_processor, has_aux=True)
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
while True:
current_batch = next_batch
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
for example in current_batch.get():
yield example
def __iter__(self):
global_chunk_size = self.config.batch_size * self.config.seq_length
if self.config.use_data_sharded_loader:
local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
else:
local_batch_size = self.config.batch_size
chunk_size = local_batch_size * self.config.seq_length
token_buffer = []
loss_mask_buffer = []
last_time = 0.0
step_times = []
start_time = time.time()
start_tokens = self._total_tokens
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
self._file_loc = loc
self._index = index
if self.config.pad:
tokens = tokens[:self.config.seq_length + 1]
tokens.extend([self._tokenizer.bos_token_id] * (self.config.seq_length + 1 - len(tokens)))
loss_masks = loss_masks[:self.config.seq_length + 1]
loss_masks.extend([0.0] * (self.config.seq_length + 1 - len(loss_masks)))
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
while len(token_buffer) > chunk_size + 1:
self._total_tokens += global_chunk_size
step_times.append(time.time() - last_time)
last_time = time.time()
if len(step_times) > self.config.throughput_average_window_size:
step_times = step_times[-self.config.throughput_average_window_size:]
average_throughput = global_chunk_size / np.mean(step_times)
accumulated_throughput = (
(self._total_tokens - start_tokens) / (time.time() - start_time)
)
metrics = {
'dataset_file_loc': loc,
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'dataset_accumulated_tps': accumulated_throughput,
'dataset_average_tps': average_throughput,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
local_batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
local_batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
local_batch_size, -1
),
}
batch.update({
'input_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),
'target_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),
})
if self.config.always_start_with_bos:
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
if self.config.use_data_sharded_loader and not self.config.return_local_batch:
mesh = self._node_info['mesh']
sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
sp_nodes_rank = jax.process_index() % sp_nodes_size
assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size)
seq_chunk_size = self.config.seq_length // sp_nodes_size
batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
yield batch, metrics
if self.config.pad:
token_buffer, loss_mask_buffer = [], []
else:
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
def _make_callback(self, v):
return lambda index: v[index]
def get_state_dict(self):
return dict(
config=self.config,
index=self._index,
file_loc=self._file_loc,
total_tokens=self._total_tokens,
)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
self._index = state_dict.get('index', self.config.example_index_at_start)
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def vocab_size(self):
return len(self.tokenizer)
class JsonVisionDataset(object):
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = ''
config.seq_length = 384
config.batch_size = 4
config.always_start_with_bos = False
config.start_seek_loc = 0
config.example_index_at_start = 0
config.tokens_count_at_start = 0
config.tokenizer_processes = 1
config.tokenizer_parallel_chunk_size = 32
config.tokenizer_parallel_batch_size = 1024
config.throughput_average_window_size = 200
config.use_data_sharded_loader = True
config.return_local_batch = False
config.mode = 'pad'
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor, node_info):
self.config = self.get_default_config(config)
assert self.config.path != ''
self._node_info = node_info
self._tokenizer = tokenizer
self._text_processor = text_processor
self._index = self.config.example_index_at_start
self._file_loc = self.config.start_seek_loc
self._total_tokens = self.config.tokens_count_at_start
def parse_json(self, line):
if not line or line == '\n':
return None
try:
data = json.loads(line)
except json.decoder.JSONDecodeError:
print(f'Error parsing json line:\n{line}')
return None
return data
def json_iterator(self):
index, file_loc = self._index, self._file_loc
with open_file(self.config.path, 'r', block_size=50 * 2 ** 20) as fin:
fin.seek(file_loc)
while True:
line = fin.readline()
file_loc = fin.tell()
if not line: # Reached EOF
index = 0
fin.seek(0)
continue
if not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']:
data = self.parse_json(line)
if data is not None:
# JSON parsing succeeded
yield data, file_loc, index
index += 1
def batched(self, iterator, batch_size):
batch = []
for example in iterator:
batch.append(example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def parallel_example_iterator(self):
if self.config.tokenizer_processes == 1:
for example, loc, index in self.json_iterator():
self._file_loc = loc
self._index = index
yield self.text_processor((example, loc, index), has_aux=True)
else:
process_pool = Pool(self.config.tokenizer_processes)
batched_iterator = self.batched(
self.json_iterator(), self.config.tokenizer_parallel_batch_size
)
with process_pool as pool:
map_fn = partial(self.text_processor, has_aux=True)
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
while True:
current_batch = next_batch
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
for example in current_batch.get():
yield example
def __iter__(self):
if self.config.mode == 'pad':
fn = self._iter_pad
elif self.config.mode == 'no_pad':
fn = self._iter_no_pad
else:
raise ValueError(f'Unknown mode: {self.config.mode}')
return fn()
def _iter_pad(self):
chunk_size = self.config.batch_size * self.config.seq_length
if self.config.use_data_sharded_loader:
local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
else:
local_batch_size = self.config.batch_size
last_time = 0.0
buffer = []
step_times = []
start_time = time.time()
start_tokens = self._total_tokens
for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():
if not keep:
continue
self._file_loc = loc
self._index = index
buffer.append((tokens, loss_masks, vision_masks))
while len(buffer) >= local_batch_size:
self._total_tokens += chunk_size
step_times.append(time.time() - last_time)
last_time = time.time()
if len(step_times) > self.config.throughput_average_window_size:
step_times = step_times[-self.config.throughput_average_window_size:]
average_throughput = chunk_size / np.mean(step_times)
accumulated_throughput = (
(self._total_tokens - start_tokens) / (time.time() - start_time)
)
metrics = {
'dataset_file_loc': loc,
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'dataset_accumulated_tps': accumulated_throughput,
'dataset_average_tps': average_throughput,
}
batch = {
'input_tokens': np.full(
(local_batch_size, self.config.seq_length),
self._tokenizer.bos_token_id,
dtype=np.int32
),
'target_tokens': np.full(
(local_batch_size, self.config.seq_length),
self._tokenizer.bos_token_id,
dtype=np.int32
),
'loss_masks': np.zeros(
(local_batch_size, self.config.seq_length),
dtype=np.float32
),
'input_vision_masks': np.zeros(
(local_batch_size, self.config.seq_length),
dtype=bool
),
'target_vision_masks': np.zeros(
(local_batch_size, self.config.seq_length),
dtype=bool
)
}
for i in range(local_batch_size):
tokens, loss_masks, vision_masks = buffer[i]
if len(tokens) > self.config.seq_length:
tokens = tokens[:self.config.seq_length + 1]
loss_masks = loss_masks[:self.config.seq_length + 1]
vision_masks = vision_masks[:self.config.seq_length + 1]
input_tokens, target_tokens = tokens[:-1], tokens[1:]
input_vision_masks, target_vision_masks = vision_masks[:-1], vision_masks[1:]
loss_masks = loss_masks[1:]
batch['input_tokens'][i, :len(input_tokens)] = input_tokens
batch['target_tokens'][i, :len(target_tokens)] = target_tokens
batch['input_vision_masks'][i, :len(input_vision_masks)] = input_vision_masks
batch['target_vision_masks'][i, :len(target_vision_masks)] = target_vision_masks
batch['loss_masks'][i, :len(loss_masks)] = loss_masks
if self.config.use_data_sharded_loader and not self.config.return_local_batch:
mesh = self._node_info['mesh']
sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
sp_nodes_rank = jax.process_index() % sp_nodes_size
assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size)
seq_chunk_size = self.config.seq_length // sp_nodes_size
batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
yield batch, metrics
buffer = buffer[local_batch_size:]
def _iter_no_pad(self):
global_chunk_size = self.config.batch_size * self.config.seq_length
if self.config.use_data_sharded_loader:
local_batch_size = self.config.batch_size // self._node_info['dp_node_size']
else:
local_batch_size = self.config.batch_size
chunk_size = local_batch_size * self.config.seq_length
token_buffer = []
loss_mask_buffer = []
vision_mask_buffer = []
last_time = 0.0
step_times = []
start_time = time.time()
start_tokens = self._total_tokens
for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():
if not keep:
continue
self._file_loc = loc
self._index = index
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
vision_mask_buffer.extend(vision_masks)
while len(token_buffer) > chunk_size + 1:
self._total_tokens += global_chunk_size
step_times.append(time.time() - last_time)
last_time = time.time()
if len(step_times) > self.config.throughput_average_window_size:
step_times = step_times[-self.config.throughput_average_window_size:]
average_throughput = global_chunk_size / np.mean(step_times)
accumulated_throughput = (
(self._total_tokens - start_tokens) / (time.time() - start_time)
)
metrics = {
'dataset_file_loc': loc,
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'dataset_accumulated_tps': accumulated_throughput,
'dataset_average_tps': average_throughput,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
local_batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
local_batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
local_batch_size, -1
),
'input_vision_masks': np.array(vision_mask_buffer[:chunk_size], dtype=bool).reshape(
local_batch_size, -1
),
'target_vision_masks': np.array(vision_mask_buffer[1:chunk_size + 1], dtype=bool).reshape(
local_batch_size, -1
),
}
if self.config.use_data_sharded_loader and not self.config.return_local_batch:
mesh = self._node_info['mesh']
sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())
sp_nodes_rank = jax.process_index() % sp_nodes_size
assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size)
seq_chunk_size = self.config.seq_length // sp_nodes_size
batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}
batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))
yield batch, metrics
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
vision_mask_buffer = vision_mask_buffer[chunk_size:]
def _make_callback(self, v):
return lambda index: v[index]
def get_state_dict(self):
return dict(
config=self.config,
index=self._index,
file_loc=self._file_loc,
total_tokens=self._total_tokens,
)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
self._index = state_dict.get('index', self.config.example_index_at_start)
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def vocab_size(self):
return len(self._tokenizer)
================================================
FILE: lwm/llama.py
================================================
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import tempfile
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as PS
from jax.experimental.shard_map import shard_map
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen import partitioning as nn_partitioning
import sentencepiece as spm
from transformers.configuration_utils import PretrainedConfig
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ml_collections import ConfigDict
from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh
from ringattention import ringattention, blockwise_feedforward, ringattention_jax, ringattention_inference
LLAMA_STANDARD_CONFIGS = {
'200m': {
'vocab_size': 32000,
'hidden_size': 1024,
'intermediate_size': 2048,
'num_hidden_layers': 14,
'num_attention_heads': 8,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
'1b': {
'vocab_size': 32000,
'hidden_size': 2048,
'intermediate_size': 5504,
'num_hidden_layers': 22,
'num_attention_heads': 16,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
'3b': {
'vocab_size': 32000,
'hidden_size': 3200,
'intermediate_size': 8640,
'num_hidden_layers': 26,
'num_attention_heads': 32,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
'7b': {
'vocab_size': 32000,
'hidden_size': 4096,
'intermediate_size': 11008,
'num_hidden_layers': 32,
'num_attention_heads': 32,
'max_sequence_length': 4096,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
'13b': {
'vocab_size': 32000,
'hidden_size': 5120,
'intermediate_size': 13824,
'num_hidden_layers': 40,
'num_attention_heads': 40,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
'30b': {
'vocab_size': 32000,
'hidden_size': 6656,
'intermediate_size': 17920,
'num_hidden_layers': 60,
'num_attention_heads': 52,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
'65b': {
'vocab_size': 32000,
'hidden_size': 8192,
'intermediate_size': 22016,
'num_hidden_layers': 80,
'num_attention_heads': 64,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-5,
'use_cache': True,
'tie_word_embeddings': False,
},
'debug': { # A small model for debugging
'vocab_size': 32000,
'hidden_size': 256,
'intermediate_size': 256,
'num_hidden_layers': 2,
'num_attention_heads': 2,
'max_sequence_length': 2048,
'initializer_range': 0.02,
'rms_norm_eps': 1e-6,
'use_cache': True,
'tie_word_embeddings': False,
},
}
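# Sketch (hypothetical helper, not part of LWM): estimate the parameter count
# implied by an entry of LLAMA_STANDARD_CONFIGS. It assumes the layer layout
# defined later in this file: separate wte/lm_head unless tie_word_embeddings is
# set, four hidden_size x hidden_size attention projections (wq, wk, wv, wo),
# three SwiGLU projections (w1, w3: hidden -> intermediate; w2: intermediate ->
# hidden), two RMSNorm weight vectors per block, a final ln_f, and no biases.
def _sketch_llama_param_count(cfg):
    v, h = cfg['vocab_size'], cfg['hidden_size']
    i, n = cfg['intermediate_size'], cfg['num_hidden_layers']
    embed = v * h                                                  # transformer/wte/embedding
    head = 0 if cfg.get('tie_word_embeddings', False) else h * v   # lm_head/kernel
    per_block = 4 * h * h + 3 * h * i + 2 * h                      # attention + feed_forward + 2 norms
    return embed + head + n * per_block + h                        # + transformer/ln_f
# Usage: the '7b' entry (vocab 32000, hidden 4096, intermediate 11008, 32 layers)
# gives 6,738,415,616 parameters under these assumptions.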
class LLaMAConfig(PretrainedConfig):
model_type = "llama"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
max_sequence_length=4096,
rms_norm_eps=1e-6,
initializer_range=0.02,
use_cache=True,
bos_token_id=0,
eos_token_id=1,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
tie_word_embeddings=False,
scan_attention=True,
scan_mlp=True,
scan_query_chunk_size=1024,
scan_key_chunk_size=1024,
scan_mlp_chunk_size=1024,
scan_layers=True,
param_scan_axis=0,
mesh_dim=None,
theta=10000,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_sequence_length = max_sequence_length
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.scan_attention = scan_attention
self.scan_mlp = scan_mlp
self.scan_query_chunk_size = scan_query_chunk_size
self.scan_key_chunk_size = scan_key_chunk_size
self.scan_mlp_chunk_size = scan_mlp_chunk_size
self.scan_layers = scan_layers
self.param_scan_axis = param_scan_axis
self.mesh_dim = mesh_dim
self.theta = theta
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@classmethod
def get_default_config(cls, updates=None):
config = function_args_to_config(cls.__init__)
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@staticmethod
def get_jax_mesh(axis_dims):
return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp'))
@staticmethod
def get_ranks_and_size(mesh):
out = dict(mesh=mesh)
mp_size = mesh.shape['tp'] * mesh.shape['sp']
mp_node_size = max(1, mp_size // jax.local_device_count())
dp_node_size = jax.process_count() // mp_node_size
out.update(mp_node_size=mp_node_size,
dp_node_size=dp_node_size)
dp_node_rank = jax.process_index() // mp_node_size
mp_node_rank = jax.process_index() % mp_node_size
out.update(dp_node_rank=dp_node_rank,
mp_node_rank=mp_node_rank)
return out
@staticmethod
def get_partition_rules(scan_layers=False, scan_axis=0):
"""Partition rules are ordered so that earlier rules match first."""
if scan_layers:
if scan_axis == 0:
return (
# embeddings
("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
# attention
("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")),
("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))),
# mlp
("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")),
("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))),
("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")),
# layer norms
("attention_norm/kernel", PS(None, None)),
("ffn_norm/kernel", PS(None, None)),
# output head
("transformer/ln_f/kernel", PS(None)),
("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
('.*', PS(None)),
)
elif scan_axis == 1:
return (
# embeddings
("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
# attention
("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")),
("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))),
# mlp
("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")),
("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))),
("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")),
# layer norms
("attention_norm/kernel", PS(None, None)),
("ffn_norm/kernel", PS(None, None)),
# output head
("transformer/ln_f/kernel", PS(None)),
("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
('.*', PS(None)),
)
else:
raise ValueError(f"Invalid scan_axis {scan_axis}")
else:
return (
# embeddings
("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
# attention
("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")),
("attention/wo/kernel", PS("tp", ("fsdp", "sp"))),
# mlp
("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")),
("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))),
("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")),
# layer norms
("attention_norm/kernel", PS(None)),
("ffn_norm/kernel", PS(None)),
# output head
("transformer/ln_f/kernel", PS(None)),
("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
('.*', PS(None)),
)
@staticmethod
def get_weight_decay_exclusions():
return tuple()
@staticmethod
def get_frozen_param_exclusions(freeze_base):
if freeze_base:
return ("vte", "vision_head")
else:
return tuple()
@staticmethod
def rng_keys():
return ('params', 'dropout')
@classmethod
def load_config(cls, path):
if path in LLAMA_STANDARD_CONFIGS:
return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])
load_type, load_path = path.split('::', 1)
if load_type == 'pickle':
return cls.from_dict(load_pickle(load_path)['llama_config'])
elif load_type == 'json':
with open_file(load_path, 'r') as fin:
raw_config = fin.read()
return cls.from_dict(json.loads(raw_config))
else:
raise ValueError(f'Unsupported load config type: {load_type}')
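# Sketch (hypothetical helper, not part of LWM): the rank arithmetic of
# LLaMAConfig.get_ranks_and_size written with plain integers, so it can be
# checked without a device mesh. tp and sp are the 'tp' and 'sp' mesh sizes;
# local_devices, process_count and process_index stand in for
# jax.local_device_count(), jax.process_count() and jax.process_index().
def _sketch_ranks_and_size(tp, sp, local_devices, process_count, process_index):
    mp_size = tp * sp                                # devices holding one model replica
    mp_node_size = max(1, mp_size // local_devices)  # hosts spanned by one replica
    dp_node_size = process_count // mp_node_size     # number of data-parallel host groups
    return dict(
        mp_node_size=mp_node_size,
        dp_node_size=dp_node_size,
        dp_node_rank=process_index // mp_node_size,
        mp_node_rank=process_index % mp_node_size,
    )
# Usage: tp=8, sp=4 on 8 hosts with 8 local devices each gives mp_node_size=4
# and dp_node_size=2; host 5 gets dp_node_rank=1 and mp_node_rank=1.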
remat = nn_partitioning.remat
logger = logging.get_logger(__name__)
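# Sketch (hypothetical helper, not part of LWM): how an ordered rule table like
# the one returned by LLaMAConfig.get_partition_rules is typically consumed.
# Assumption: like tux.match_partition_rules, each "/"-joined parameter path is
# tested with re.search against the rules in order and the first hit wins. This
# is why the catch-all '.*' rule must come last. Specs are plain strings here so
# the sketch runs without JAX.
import re

def _sketch_match_rule(rules, param_path):
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_path}")

_SKETCH_RULES = (
    ("attention/(wq|wk|wv)/kernel", "PS(('fsdp','sp'), 'tp')"),
    ("attention/wo/kernel", "PS('tp', ('fsdp','sp'))"),
    (".*", "PS(None)"),
)
# Usage: "transformer/h/0/attention/wk/kernel" hits the first rule, while
# "transformer/ln_f/kernel" falls through to the catch-all PS(None).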
class RMSNorm(nn.Module):
dim: int
eps: float=1e-6
dtype: jnp.dtype=jnp.float32
param_dtype: jnp.dtype=jnp.float32
def setup(self) -> None:
self.weight = self.param(
'kernel',
nn.initializers.ones,
(self.dim,),
self.param_dtype,
)
def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
output = self._norm(x).astype(self.dtype)
weight = jnp.asarray(self.weight, self.dtype)
return output * weight
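# Sketch (hypothetical helper, not part of LWM): the RMSNorm computation above in
# NumPy, y = x / sqrt(mean(x**2, axis=-1) + eps) * weight. Everything is kept in
# float32 here, whereas RMSNorm.__call__ casts back to self.dtype before scaling.
import numpy as np

def _sketch_rms_norm(x, weight, eps=1e-6):
    x32 = np.asarray(x, dtype=np.float32)
    rms_inv = 1.0 / np.sqrt(np.mean(np.square(x32), axis=-1, keepdims=True) + eps)
    return (x32 * rms_inv) * np.asarray(weight, dtype=np.float32)
# Usage: with a unit weight, every output row has a root-mean-square close to 1.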
def precompute_freqs_cis(dim: int, max_position_embedding: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:
freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
t = np.arange(max_position_embedding) # type: ignore
freqs = np.outer(t, freqs).astype(dtype) # type: ignore
sin, cos = np.sin(freqs), np.cos(freqs)
freqs_cis = np.complex64(cos + 1j * sin)
return jnp.asarray(freqs_cis)
def apply_rotary_emb(
xq: jnp.ndarray,
xk: jnp.ndarray,
freqs_cis: jnp.ndarray,
dtype: jnp.dtype=jnp.float32,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
# add head dim
freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
xq_out = xq_ * freqs_cis
xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
xk_out = xk_ * freqs_cis
xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
return xq_out.astype(dtype), xk_out.astype(dtype)
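# Sketch (hypothetical helper, not part of LWM): rotary embeddings as complex
# multiplication, following precompute_freqs_cis and apply_rotary_emb above.
# Each adjacent pair (x[2k], x[2k+1]) is treated as one complex number and
# rotated by angle pos * theta**(-2k/dim). A rotation preserves length, and the
# query-key dot product depends only on the position difference.
import numpy as np

def _sketch_rope(x, pos, theta=10000.0):
    dim = x.shape[-1]
    freqs = 1.0 / (theta ** (np.arange(0, dim, 2) / dim))  # (dim // 2,)
    rot = np.exp(1j * pos * freqs)                          # cos + i*sin
    xc = x[..., 0::2] + 1j * x[..., 1::2]                   # pair up the last axis
    out = xc * rot
    # interleave real/imag parts back into [re0, im0, re1, im1, ...]
    return np.stack([out.real, out.imag], axis=-1).reshape(x.shape)
# Usage: _sketch_rope(q, 5) @ _sketch_rope(k, 3) equals
# _sketch_rope(q, 12) @ _sketch_rope(k, 10) up to floating-point error.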
class FlaxLLaMAAttention(nn.Module):
config: LLaMAConfig
dtype: jnp.dtype=jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self):
config = self.config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.wq = nn.Dense(
config.num_attention_heads*self.head_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.wk = nn.Dense(
config.num_attention_heads*self.head_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.wv = nn.Dense(
config.num_attention_heads*self.head_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.wo = nn.Dense(
config.hidden_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
self.freqs_cis = precompute_freqs_cis(
self.head_dim,
config.max_sequence_length,
theta=config.theta,
dtype=self.dtype,
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
if query.shape[1] == 1:
mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
def fn(cached_key, cached_value, key, value, cur_index):
assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
sp_size = max_length // mesh.shape['sp']
axis_index = jax.lax.axis_index('sp')
cur_index = cur_index - axis_index * sp_size
key, value = jax.lax.cond(
jnp.logical_and(cur_index >= 0, cur_index < sp_size),
lambda: (
cached_key.at[:, cur_index].set(key[:, -1]),
cached_value.at[:, cur_index].set(value[:, -1]),
),
lambda: (cached_key, cached_value),
)
return key, value
fn = shard_map(
fn, mesh=mesh,
in_specs=(
PS(('dp', 'fsdp'), 'sp', 'tp', None),
PS(('dp', 'fsdp'), 'sp', 'tp', None),
PS(('dp', 'fsdp'), None, 'tp', None),
PS(('dp', 'fsdp'), None, 'tp', None),
PS()
),
out_specs=(
PS(('dp', 'fsdp'), 'sp', 'tp', None),
PS(('dp', 'fsdp'), 'sp', 'tp', None)
),
check_rep=False
)
key, value = fn(cached_key.value, cached_value.value, key, value, cur_index)
else:
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
return key, value, attention_mask
def __call__(
self,
hidden_states,
attention_mask,
segment_ids,
position_ids,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
if xq.shape[1] == 1:
xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "tp"))
else:
xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), "sp", "tp"))
xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), "sp", "tp"))
xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), "sp", "tp"))
xq = self._split_heads(xq)
xk = self._split_heads(xk)
xv = self._split_heads(xv)
freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
dropout_rng = None
if not deterministic and self.config.attn_pdrop > 0.0:
dropout_rng = self.make_rng("dropout")
if self.config.scan_attention and xq.shape[1] > max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size):
# expand the mask to (batch, 1, 1, seq) instead of materializing an n x n mask; ringattention handles the rest
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
if self.has_variable("cache", "cached_key") or init_cache:
xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
# transform boolean mask into float mask
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
attn_weights = None
ring_attention_sharded = shard_map(
partial(
ringattention,
axis_name="sp",
float32_logits=True,
cache_idx=None,
blockwise_kwargs=dict(
causal_block_size=1,
deterministic=deterministic,
dropout_rng=dropout_rng,
attn_pdrop=self.config.attn_pdrop,
query_chunk_size=self.config.scan_query_chunk_size,
key_chunk_size=self.config.scan_key_chunk_size,
dtype=self.dtype,
policy=jax.checkpoint_policies.nothing_saveable,
precision=self.precision,
prevent_cse=not self.config.scan_layers,
)
),
mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
in_specs=(
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), None, None, None),
PS(("dp", "fsdp"), None),
),
out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
check_rep=False
)
attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)
attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), "sp", "tp", None))
else:
query_length, key_length = xq.shape[1], xk.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = jnp.arange(max_decoder_length)[None] <= (jnp.arange(query_length) + mask_shift)[:, None]
causal_mask = causal_mask[None, None]
segment_mask = None
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
if segment_ids is not None:
segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]
segment_mask = segment_mask[:, None]
else:
segment_mask = None
batch_size = hidden_states.shape[0]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask, segment_mask)
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.has_variable("cache", "cached_key") or init_cache:
xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
q_sp_dim = None if xq.shape[1] == 1 else 'sp'
attn_weights = None
ring_attention_sharded = shard_map(
partial(ringattention_inference, axis_name="sp"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
in_specs=(
PS(("dp", "fsdp"), q_sp_dim, "tp", None),
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), "sp", "tp", None),
PS(("dp", "fsdp"), None, q_sp_dim, None)
),
out_specs=PS(("dp", "fsdp"), q_sp_dim, "tp", None),
check_rep=False
)
attn_output = ring_attention_sharded(
xq, xk, xv, attention_mask
)
attn_output = self._merge_heads(attn_output)
attn_output = self.wo(attn_output)
attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
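# Sketch (hypothetical helpers, not part of LWM): the KV-cache bookkeeping in
# FlaxLLaMAAttention, simulated with NumPy.
# 1) _concatenate_to_cache, single-token case: the cache's sequence axis is
#    split into sp equal shards (the 'sp' mesh axis). Each shard converts the
#    global cur_index to a local one, cur_index - shard_id * shard_len, and only
#    the shard whose local index falls in [0, shard_len) writes the new entry.
# 2) The non-ring branch of __call__ with a cache: query row q sits at absolute
#    position q + cache_index and may attend to cache slots <= that position.
import numpy as np

def _sketch_sharded_cache_write(cache, new_entry, cur_index, sp):
    # cache: (batch, max_length, heads, head_dim); new_entry: (batch, heads, head_dim)
    shard_len = cache.shape[1] // sp  # max_length assumed divisible by sp
    shards = np.split(cache.copy(), sp, axis=1)
    for shard_id, shard in enumerate(shards):
        local = cur_index - shard_id * shard_len
        if 0 <= local < shard_len:
            shard[:, local] = new_entry
    return np.concatenate(shards, axis=1)

def _sketch_cached_causal_mask(max_decoder_length, query_length, cache_index):
    slots = np.arange(max_decoder_length)[None]                   # (1, max_len)
    positions = (np.arange(query_length) + cache_index)[:, None]  # (q_len, 1)
    return slots <= positions                                     # (q_len, max_len)

# Usage: writing position 5 into an 8-slot cache split over sp=4 shards touches
# only shard 2 (local index 1), matching the unsharded cache[:, 5] = new_entry;
# _sketch_cached_causal_mask(5, 1, 2) is [[True, True, True, False, False]].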
class FlaxLLaMAMLP(nn.Module):
config: LLaMAConfig
dtype: jnp.dtype=jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self) -> None:
config = self.config
self.w1 = nn.Dense(
config.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.w2 = nn.Dense(
config.hidden_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.w3 = nn.Dense(
config.intermediate_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
precision=self.precision,
)
self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
x = self.dropout(x, deterministic=deterministic)
return x
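# Sketch (hypothetical helper, not part of LWM): the gated feed-forward in
# FlaxLLaMAMLP.__call__, w2(silu(w1 x) * w3 x), using plain NumPy weight matrices
# and silu(z) = z * sigmoid(z). w1 and w3 map hidden -> intermediate and w2 maps
# intermediate -> hidden, matching the Dense sizes in setup(). Dropout is omitted.
import numpy as np

def _sketch_swiglu(x, w1, w2, w3):
    def silu(z):
        return z / (1.0 + np.exp(-z))
    return (silu(x @ w1) * (x @ w3)) @ w2
# Usage: an input of shape (tokens, hidden) maps back to (tokens, hidden), and a
# zero input gives a zero output because silu(0) = 0.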
class FlaxLLaMABlock(nn.Module):
config: LLaMAConfig
dtype: jnp.dtype=jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self) -> None:
attention_module = FlaxLLaMAAttention
mlp_module = FlaxLLaMAMLP
if self.config.scan_mlp:
mlp_module = remat(
mlp_module, static_argnums=(1,),
policy=jax.checkpoint_policies.nothing_saveable,
prevent_cse=not self.config.scan_layers,
)
self.attention = attention_module(
self.config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)
self.feed_forward = mlp_module(
self.config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)
self.attention_norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.ffn_norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
dtype=self.dtype,
param_dtype=self.param_dtype,
)
def __call__(
self,
hidden_states,
attention_mask=None,
segment_ids=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
attn_outputs = self.attention(
self.attention_norm(hidden_states),
attention_mask,
segment_ids,
position_ids,
deterministic,
init_cache,
output_attentions,
)
attn_output = attn_outputs[0]
hidden_states = hidden_states + attn_output
feed_forward_input = self.ffn_norm(hidden_states)
if self.config.scan_mlp and hidden_states.shape[1] >= self.config.scan_mlp_chunk_size:
feed_forward_hidden_states = blockwise_feedforward(
self.feed_forward,
feed_forward_input,
self.config.scan_mlp_chunk_size,
pre_remat=True,
)
else:
feed_forward_hidden_states = self.feed_forward(feed_forward_input, deterministic)
feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "tp"))
hidden_states = hidden_states + feed_forward_hidden_states
outputs = hidden_states
if self.config.scan_layers:
outputs = (outputs, None)
return outputs
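# Sketch (hypothetical helper, not part of LWM): the idea behind the
# blockwise_feedforward call in FlaxLLaMABlock. A position-wise function can be
# applied to the sequence in chunks of chunk_size and the results concatenated.
# This gives the same output while only one chunk's intermediate activations are
# live at a time. This plain loop does not model any rematerialization or
# scanning that the library helper may perform.
import numpy as np

def _sketch_blockwise_apply(fn, x, chunk_size):
    # x: (batch, seq_len, hidden); seq_len must be a multiple of chunk_size
    seq_len = x.shape[1]
    assert seq_len % chunk_size == 0, "seq_len must be a multiple of chunk_size"
    chunks = [fn(x[:, s:s + chunk_size]) for s in range(0, seq_len, chunk_size)]
    return np.concatenate(chunks, axis=1)
# Usage: for any position-wise fn, _sketch_blockwise_apply(fn, x, c) matches fn(x).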
class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LLaMAConfig
base_model_prefix = "transformer"
module_class: nn.Module = None
def __init__(
self,
config: LLaMAConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
segment_ids = None
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
segment_ids,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, segment_ids, position_ids, return_dict=False)
random_params = module_init_outputs["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
segment_ids = None
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
)
return init_variables["cache"]
@add_start_docstrings_to_model_forward("")
def __call__(
self,
input_ids,
attention_mask=None,
segment_ids=None,
position_ids=None,
params: dict = None,
past_key_values: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
batch_size, sequence_length = input_ids.shape
if position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
segment_ids,
jnp.array(position_ids, dtype="i4"),
not train,
False,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
return outputs
class FlaxLLaMABlockCollection(nn.Module):
config: LLaMAConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
@nn.compact
def __call__(
self,
hidden_states,
attention_mask=None,
segment_ids=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
block = FlaxLLaMABlock
if self.config.scan_layers:
initializing = self.is_mutable_collection('params')
params_spec = (
self.config.param_scan_axis if initializing else
nn_partitioning.ScanIn(self.config.param_scan_axis))
cache_spec = 0
hidden_states, _ = nn.scan(
block,
variable_axes={
'params': params_spec,
'cache': cache_spec,
'intermediates': 0
},
split_rngs={
'params': True,
'dropout': True
},
in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
length=self.config.num_hidden_layers,
metadata_params={nn.PARTITION_NAME: 'scan_decoder_layer'},
)(self.config, name='scan_decoder', dtype=self.dtype, param_dtype=self.param_dtype,)(
hidden_states,
attention_mask,
segment_ids,
position_ids,
deterministic,
init_cache,
output_attentions,
)
else:
blocks = [
block(
self.config,
name=str(i),
dtype=self.dtype,
param_dtype=self.param_dtype,
) for i in range(self.config.num_hidden_layers)
]
for block in blocks:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states,
attention_mask,
segment_ids,
position_ids,
deterministic,
init_cache,
output_attentions,
)
hidden_states = layer_outputs
if output_attentions:
all_attentions += (layer_outputs[1],)
outputs = (hidden_states, all_hidden_states, all_attentions)
return outputs
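# Sketch (hypothetical helper, not part of LWM): what scan_layers does to
# parameter shapes. nn.scan stores the weights of all num_hidden_layers blocks in
# one array, stacked along param_scan_axis ('params': params_spec above). That
# extra layer axis explains the extra None in the scan-mode partition rules. The
# None sits at axis 0 for param_scan_axis=0, e.g. PS(None, ("fsdp", "sp"), "tp"),
# and at axis 1 for param_scan_axis=1, e.g. PS(("fsdp", "sp"), None, "tp").
import numpy as np

def _sketch_stack_layer_kernels(per_layer_kernels, param_scan_axis=0):
    # per_layer_kernels: list of (in_features, out_features) arrays, one per block
    return np.stack(per_layer_kernels, axis=param_scan_axis)
# Usage: three (4, 6) kernels stack to (3, 4, 6) with axis 0 and to (4, 3, 6) with axis 1.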
class FlaxLLaMAModule(nn.Module):
config: LLaMAConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self):
self.embed_dim = self.config.hidden_size
self.wte = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
def __call__(
self,
input_ids,
attention_mask,
segment_ids,
position_ids,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
input_embeds = self.wte(input_ids.astype("i4"))
assert input_embeds.shape[1] <= self.config.max_sequence_length, f"Input sequence length {input_embeds.shape[1]} larger than max supported sequence length {self.config.max_sequence_length}"
hidden_states = self.dropout(input_embeds, deterministic=deterministic)
outputs = self.h(
hidden_states,
attention_mask,
segment_ids=segment_ids,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[-1],
)
class FlaxLLaMAForCausalLMModule(nn.Module):
config: LLaMAConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self):
self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
self.lm_head = nn.Dense(
self.config.vocab_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
precision=self.precision,
)
def __call__(
self,
input_ids,
attention_mask=None,
segment_ids=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
batch_size, seq_length = input_ids.shape
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = jnp.arange(seq_length, dtype=jnp.int32)[None].repeat(batch_size, axis=0)
outputs = self.transformer(
input_ids,
attention_mask,
segment_ids,
position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
@add_start_docstrings("", "")
class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
module_class = FlaxLLaMAForCausalLMModule
def prepare_inputs_for_generation(
self, input_ids, max_length,
attention_mask: Optional[jax.Array] = None,
):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
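# Sketch (hypothetical helper, not part of LWM): the position bookkeeping in
# prepare_inputs_for_generation and update_inputs_for_generation. With a
# left-padded prompt, cumsum(mask) - 1 numbers the real tokens 0, 1, 2, ...
# (padding gets negative values). Each decoding step then feeds only the next
# position, last + 1.
import numpy as np

def _sketch_generation_positions(attention_mask, steps):
    mask = np.asarray(attention_mask, dtype=np.int32)
    position_ids = mask.cumsum(axis=-1) - 1      # prompt positions
    history = [position_ids]
    for _ in range(steps):
        position_ids = position_ids[:, -1:] + 1  # one new token per step
        history.append(position_ids)
    return history
# Usage: mask [[0, 0, 1, 1, 1]] gives prompt positions [[-1, -1, 0, 1, 2]],
# followed by [[3]] and [[4]] for the next two decoding steps.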
================================================
FILE: lwm/train.py
================================================
import pprint
import os
from functools import partial
from tqdm import tqdm, trange
import numpy as np
from absl.app import run
import absl.logging as logging
import tux
import jax
import flax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState
from transformers import AutoTokenizer
from lwm.data import DatasetFactory
from tux import (
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
set_random_seed, average_metrics, get_mask,
make_shard_and_gather_fns, with_sharding_constraint, define_flags_with_default,
OptimizerFactory, StreamingCheckpointer
)
from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLMModule
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLMModule
FLAGS, FLAGS_DEF = define_flags_with_default(
modality='text',
use_data_sharded_loader=True,
seed=42,
mesh_dim='1,-1,1,1',
dtype='fp32',
total_steps=10000,
load_llama_config='',
update_llama_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer='LargeWorldModel/LWM-Text-1M',
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
llama=VideoLLaMAConfig.get_default_config(),
logger=tux.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
autoresume=False,
)
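# Sketch (hypothetical helper, not part of LWM): how a mesh_dim string such as
# the default '1,-1,1,1' maps onto the ('dp', 'fsdp', 'tp', 'sp') axes used by
# LLaMAConfig.get_jax_mesh. Assumption: as in EasyLM's get_jax_mesh, the -1
# entry is resolved the way NumPy's reshape infers one dimension. Under that
# assumption, the default places every device on the fsdp axis. Only the plain
# comma-separated form is handled here.
import numpy as np

def _sketch_resolve_mesh_dim(mesh_dim, device_count, names=('dp', 'fsdp', 'tp', 'sp')):
    dims = [int(d) for d in mesh_dim.split(',')]
    assert len(dims) == len(names), "one size per mesh axis"
    shape = np.arange(device_count).reshape(dims).shape
    return dict(zip(names, shape))
# Usage: on 8 devices, '1,-1,1,1' resolves to dp=1, fsdp=8, tp=1, sp=1.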
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = tux.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = tux.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
if jax.process_index() == 0:
output_dir = logger.output_dir
else:
output_dir = os.path.join(logger.output_dir, logger.experiment_id)
if FLAGS.modality == 'text':
config_cls = LLaMAConfig
llama_cls = FlaxLLaMAForCausalLMModule
elif FLAGS.modality == 'vision,text':
config_cls = VideoLLaMAConfig
llama_cls = FlaxVideoLLaMAForCausalLMModule
else:
raise ValueError(f"Unsupported modality: {FLAGS.modality}")
mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim)
node_info = config_cls.get_ranks_and_size(mesh)
tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info)
if FLAGS.autoresume and tux.check_exists(output_dir):
logging.info('Found existing output. Resuming dataset from latest checkpoint...')
resume_path = f"{output_dir}/dataset.pkl"
dataset.load_state_dict(tux.load_pickle(resume_path))
elif FLAGS.load_dataset_state != '':
dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
if FLAGS.load_llama_config != '':
llama_config = config_cls.load_config(FLAGS.load_llama_config)
updates = config_cls(**FLAGS.llama)
llama_config.update(dict(
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
scan_key_chunk_size=updates.scan_key_chunk_size,
scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
scan_layers=updates.scan_layers,
param_scan_axis=updates.param_scan_axis,
))
else:
llama_config = config_cls(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
))
if llama_config.vocab_size < dataset.vocab_size:
llama_config.update(dict(vocab_size=dataset.vocab_size))
llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
model = llama_cls(
llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_mask(config_cls.get_weight_decay_exclusions()),
None,
)
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
batch = 512
if FLAGS.modality == 'text':
params = model.init(
input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),
rngs=rng_generator(llama_config.rng_keys()),
)
elif FLAGS.modality == 'vision,text':
params = model.init(
input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
vision_masks=jnp.zeros((batch, seq_length), dtype=bool),
position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),
rngs=rng_generator(llama_config.rng_keys()),
)
else:
raise ValueError(f"Unsupported modality: {FLAGS.modality}")
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
def loss_and_accuracy(params):
if FLAGS.modality == 'text':
logits = model.apply(
params,
batch['input_tokens'],
deterministic=False,
rngs=rng_generator(llama_config.rng_keys()),
).logits
loss, acc = cross_entropy_loss_and_accuracy(
logits,
batch['target_tokens'],
batch['loss_masks']
)
metrics = dict(acc=acc)
return loss, metrics
elif FLAGS.modality == 'vision,text':
vision_logits, text_logits = model.apply(
params,
batch['input_tokens'],
batch['input_vision_masks'],
deterministic=False,
rngs=rng_generator(llama_config.rng_keys()),
).logits
vision_loss, vision_acc = cross_entropy_loss_and_accuracy(
vision_logits,
jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),
batch['loss_masks'] * batch['target_vision_masks']
)
text_loss, text_acc = cross_entropy_loss_and_accuracy(
text_logits,
jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),
batch['loss_masks'] * (1.0 - batch['target_vision_masks'])
)
loss = 0.5 * (vision_loss + text_loss)
metrics = dict(
vision_loss=vision_loss,
vision_acc=vision_acc,
text_loss=text_loss,
text_acc=text_acc,
)
else:
raise ValueError(f"Unsupported modality: {FLAGS.modality}")
return loss, metrics
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, loss_metrics), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
param_norm=global_norm(train_state.params),
gradient_norm=global_norm(grads),
**loss_metrics
)
return train_state, rng_generator(), metrics
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
if FLAGS.modality == 'text':
logits = model.apply(
train_state.params,
batch['input_tokens'],
deterministic=True,
rngs=rng_generator(llama_config.rng_keys()),
).logits
loss, acc = cross_entropy_loss_and_accuracy(
logits,
batch['target_tokens'],
batch['loss_masks']
)
metrics = dict(
eval_loss=loss,
eval_acc=acc,
)
elif FLAGS.modality == 'vision,text':
vision_logits, text_logits = model.apply(
train_state.params,
batch['input_tokens'],
batch['input_vision_masks'],
deterministic=True,
rngs=rng_generator(llama_config.rng_keys()),
).logits
vision_loss, vision_acc = cross_entropy_loss_and_accuracy(
vision_logits,
jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),
batch['loss_masks'] * batch['target_vision_masks']
)
text_loss, text_acc = cross_entropy_loss_and_accuracy(
text_logits,
jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),
batch['loss_masks'] * (1.0 - batch['target_vision_masks'])
)
loss = 0.5 * (vision_loss + text_loss)
metrics = dict(
eval_loss=loss,
eval_vision_accuracy=vision_acc,
eval_vision_loss=vision_loss,
eval_text_accuracy=text_acc,
eval_text_loss=text_loss,
)
return rng_generator(), metrics
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
config_cls.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0,
)
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
if FLAGS.use_data_sharded_loader:
batch_spec = PS(('dp', 'fsdp'), 'sp')
else:
batch_spec = PS()
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), batch_spec),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
llama_config=llama_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
with mesh:
train_state, restored_params = None, None
if FLAGS.autoresume and tux.check_exists(output_dir):
logging.info('Found existing output. Resuming model from latest checkpoint...')
resume_path = f"trainstate::{output_dir}/streaming_train_state"
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
resume_path, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30
)
elif FLAGS.load_checkpoint != '':
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(flax.core.unfreeze(restored_params))
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metrics = jax.device_get(eval_metrics)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
run(main)
================================================
FILE: lwm/vision_chat.py
================================================
from absl.app import run
import math
from tqdm import tqdm
from PIL import Image
import decord
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from transformers import GenerationConfig, AutoTokenizer
from tux import (
define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
match_partition_rules, make_shard_and_gather_fns,
with_sharding_constraint, tree_apply, open_file
)
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
from lwm.vqgan import VQGAN
FLAGS, FLAGS_DEF = define_flags_with_default(
prompt="",
input_file="",
vqgan_checkpoint="",
temperature=0.2,
max_n_frames=8,
seed=1234,
mesh_dim='1,-1,1,1',
dtype='fp32',
load_llama_config='',
update_llama_config='',
load_checkpoint='',
tokenizer='LargeWorldModel/LWM-Text-1M',
llama=VideoLLaMAConfig.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
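# Illustrative sketch (not part of the original script): how `Sampler.block_size`
# and `Sampler.construct_input` size and fill the prompt buffer. The 257 tokens
# per frame and the 256-token buffer mirror the constants set in
# `Sampler.__init__`. The function names are hypothetical. The buffer is rounded
# up to a multiple of the block size, presumably so that attention chunks divide
# evenly across the sequence-parallel ('sp') shards.
import math
import numpy as np


def _sketch_block_size(scan_query_chunk_size, scan_key_chunk_size, sp):
    # Largest attention chunk times the number of sequence-parallel shards.
    return max(scan_query_chunk_size, scan_key_chunk_size) * sp


def _sketch_max_input_length(max_n_frames, block_size, n_tokens_per_frame=257, min_buffer_size=256):
    raw = max_n_frames * n_tokens_per_frame + min_buffer_size
    return int(math.ceil(raw / block_size) * block_size)


def _sketch_left_pad(token_lists, max_input_length):
    # Right-align each prompt so that generation continues from the last real
    # token. The attention mask marks which positions hold real tokens.
    input_ids = np.zeros((len(token_lists), max_input_length), dtype=int)
    attention_mask = np.zeros((len(token_lists), max_input_length), dtype=int)
    for i, tokens in enumerate(token_lists):
        input_ids[i, -len(tokens):] = tokens
        attention_mask[i, -len(tokens):] = 1
    return input_ids, attention_mask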
class Sampler:
def __init__(self):
self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
self.n_tokens_per_frame = 257
self.min_buffer_size = 256
self.sharded_rng = next_rng()
self._load_model()
@property
def block_size(self):
return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
@property
def data_dim(self):
return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
def _process_frame(self, image, size):
width, height = image.size
if width < height:
new_width = size
new_height = int(size * height / width)
else:
new_height = size
new_width = int(size * width / height)
image = image.resize((new_width, new_height))
left = (new_width - size) / 2
top = (new_height - size) / 2
right = (new_width + size) / 2
bottom = (new_height + size) / 2
image = image.crop((left, top, right, bottom))
return np.array(image, dtype=np.float32) / 127.5 - 1
def _read_process_vision(self, path, max_n_frames):
f = open_file(path, 'rb')
if path.endswith('.png') or path.endswith('.jpg'):
image = Image.open(f).convert('RGB')
vision = self._process_frame(image, 256)[None]
else:
vr = decord.VideoReader(f, ctx=decord.cpu(0))
duration = len(vr)
if duration <= max_n_frames:
frame_id_list = list(range(duration))
else:
frame_id_list = np.linspace(0, duration - 1, max_n_frames, dtype=int).tolist()
video = vr.get_batch(frame_id_list).asnumpy()
vision = np.stack([self._process_frame(Image.fromarray(frame), 256) for frame in video])
# Encode frames with the VQGAN in chunks of B, front-padding the last chunk to a multiple of B.
B = 1
encodings = []
for i in range(0, len(vision), B):
v = vision[i:i + B]
if len(v) % B == 0:
n_pad = 0
else:
n_pad = B - len(v) % B
v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))
enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)
enc = enc[n_pad:]
for t in range(len(enc)):
encodings.extend(enc[t].reshape(-1).tolist())
if t == len(enc) - 1:
encodings.append(8193)
else:
encodings.append(8192)
return encodings
def construct_input(self, prompts, max_n_frames):
max_input_length = max_n_frames * self.n_tokens_per_frame + self.min_buffer_size
max_input_length = int(math.ceil(max_input_length / self.block_size) * self.block_size)
vision_start = self.tokenizer.encode('<vision>')
vision_end = self.tokenizer.encode('</vision>')
input_ids = np.zeros((len(prompts), max_input_length), dtype=int)
vision_masks = np.zeros((len(prompts), max_input_length), dtype=bool)
attention_mask = np.zeros((len(prompts), max_input_length), dtype=int)
for i, prompt in enumerate(tqdm(prompts)):
vision = self._read_process_vision(prompt['input_path'], max_n_frames)
text_1 = self.tokenizer.encode(f"You are a helpful assistant. USER: {prompt['question']}\n")
tail = self.tokenizer.encode(" ASSISTANT:")
tokens, vm = [], []
tokens.extend(text_1)
vm.extend([False] * len(text_1))
tokens.extend(vision_start)
vm.extend([False] * len(vision_start))
tokens.extend(vision)
vm.extend([True] * len(vision))
tokens.extend(vision_end)
vm.extend([False] * len(vision_end))
tokens.extend(tail)
vm.extend([False] * len(tail))
assert len(tokens) < max_input_length, (len(tokens), max_input_length)
assert len(tokens) == len(vm)
input_ids[i, -len(tokens):] = tokens
vision_masks[i, -len(tokens):] = vm
attention_mask[i, -len(tokens):] = 1
return {
'input_ids': input_ids,
'vision_masks': vision_masks,
'attention_mask': attention_mask
}
def _load_model(self):
if FLAGS.load_llama_config != '':
llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
updates = VideoLLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
scan_key_chunk_size=updates.scan_key_chunk_size,
scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
scan_layers=updates.scan_layers,
param_scan_axis=updates.param_scan_axis,
))
else:
llama_config = VideoLLaMAConfig(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
))
llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
self.config = llama_config
self.model = FlaxVideoLLaMAForCausalLM(
llama_config,
input_shape=(512, self.block_size),
seed=FLAGS.seed,
_do_init=False,
dtype=get_float_dtype_by_name(FLAGS.dtype),
)
with jax.default_device(jax.devices("cpu")[0]):
_, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
)
self.model_ps = match_partition_rules(
VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
)
shard_fns, _ = make_shard_and_gather_fns(
self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
with self.mesh:
self.params = tree_apply(shard_fns, self.params)
@cached_property
def _forward_generate(self):
def fn(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
rng_generator = JaxRNG(rng)
output = self.model.generate(
batch['input_ids'],
vision_masks=batch['vision_masks'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=self.block_size,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
temperature=FLAGS.temperature,
do_sample=True,
)
).sequences[:, batch['input_ids'].shape[1]:]
return output, rng_generator()
return pjit(
fn,
in_shardings=(self.model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def __call__(self, prompts, max_n_frames):
batch = self.construct_input(prompts, max_n_frames)
with self.mesh:
output, self.sharded_rng = self._forward_generate(
self.params, self.sharded_rng, batch
)
output = jax.device_get(output)
output_text = []
for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
if self.tokenizer.eos_token in text:
text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
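# Illustrative sketch (not part of the original script): frame selection and
# token layout in `_read_process_vision`. Each frame contributes its 16x16 = 256
# VQGAN codes followed by one separator. That separator is 8192 after every
# frame except the last, which gets 8193. This is why `n_tokens_per_frame`
# is 257. The function names are hypothetical.
import numpy as np


def _sketch_frame_ids(duration, max_n_frames):
    # Keep every frame of short clips. Otherwise sample evenly, including both ends.
    if duration <= max_n_frames:
        return list(range(duration))
    return np.linspace(0, duration - 1, max_n_frames, dtype=int).tolist()


def _sketch_flatten_codes(codes):
    # codes: integer array of shape (n_frames, 16, 16) from the VQGAN encoder.
    tokens = []
    for t, frame in enumerate(codes):
        tokens.extend(frame.reshape(-1).tolist())
        tokens.append(8193 if t == len(codes) - 1 else 8192)
    return tokens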
def main(argv):
assert FLAGS.prompt != ''
assert FLAGS.input_file != ''
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]
sampler = Sampler()
output = sampler(prompts, FLAGS.max_n_frames)[0]
print(f"Question: {FLAGS.prompt}\nAnswer: {output}")
if __name__ == "__main__":
run(main)
================================================
FILE: lwm/vision_generation.py
================================================
from absl.app import run
from tqdm import tqdm
import imageio
import numpy as np
from PIL import Image
from transformers import GenerationConfig, AutoTokenizer
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from tux import (
define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
set_random_seed, get_float_dtype_by_name, JaxRNG,
match_partition_rules, make_shard_and_gather_fns,
with_sharding_constraint, tree_apply, next_rng
)
from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
from lwm.vqgan import VQGAN
FLAGS, FLAGS_DEF = define_flags_with_default(
prompt='Fireworks over the city',
output_file='',
temperature_image=1.0,
temperature_video=1.0,
top_k_image=8192,
top_k_video=100,
cfg_scale_image=1.0,
cfg_scale_video=1.0,
vqgan_checkpoint='',
n_frames=1,
seed=1234,
mesh_dim='1,-1,1,1',
dtype='fp32',
load_llama_config='',
update_llama_config='',
load_checkpoint='',
tokenizer='LargeWorldModel/LWM-Text-1M',
llama=VideoLLaMAConfig.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
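# Illustrative sketch (not part of the original script): the batch layout used
# for classifier-free guidance. `generate_first_frame` and `generate_video_pred`
# append one empty (unconditional) prompt per real prompt. Rows [0, B) are
# therefore conditional and rows [B, 2B) unconditional, and `cfg_scales` holds
# one entry per real prompt. The mixing rule below is the standard
# classifier-free guidance formula. Whether `generate_vision` combines logits
# exactly this way is an assumption. The function name is hypothetical.
import numpy as np


def _sketch_cfg_logits(logits, cfg_scales):
    # logits: (2B, vocab); cfg_scales: (B,)
    cond, uncond = np.split(logits, 2, axis=0)
    scale = np.asarray(cfg_scales)[:, None]
    # A scale of 1.0 returns the conditional logits. Larger values push further
    # away from the unconditional prediction.
    return uncond + scale * (cond - uncond)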
def main(argv):
assert FLAGS.output_file != ''
if FLAGS.output_file.endswith('mp4'):
assert FLAGS.n_frames > 1
elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'):
assert FLAGS.n_frames == 1
else:
raise ValueError(f"Unsupported output file extension: {FLAGS.output_file}")
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
tokens_per_frame = 257
vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)
mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
if FLAGS.load_llama_config != '':
llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)
updates = VideoLLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
scan_key_chunk_size=updates.scan_key_chunk_size,
scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
scan_layers=updates.scan_layers,
param_scan_axis=updates.param_scan_axis,
))
else:
llama_config = VideoLLaMAConfig(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
))
llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
with jax.default_device(jax.devices("cpu")[0]):
_, params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
)
model = FlaxVideoLLaMAForCausalLM(
llama_config,
input_shape=(512, 8192),
seed=FLAGS.seed,
_do_init=False,
dtype=get_float_dtype_by_name(FLAGS.dtype),
)
model_ps = match_partition_rules(
VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params
)
shard_fns, _ = make_shard_and_gather_fns(
model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
with mesh:
params = tree_apply(shard_fns, params)
def _forward_generate(params, rng, batch, n_tokens, cfg_scale, top_k, temperature):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
cfg_scales = jnp.ones((batch['input_ids'].shape[0] // 2,), dtype=jnp.float32) * cfg_scale
cfg_scales = with_sharding_constraint(cfg_scales, PS(('dp', 'fsdp')))
rng_generator = JaxRNG(rng)
output = model.generate_vision(
batch['input_ids'],
cfg_scales,
attention_mask=batch['attention_mask'],
vision_masks=batch['vision_masks'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=n_tokens,
min_new_tokens=n_tokens,
pad_token_id=tokenizer.pad_token_id,
temperature=temperature,
do_sample=True,
top_k=top_k,
)
).sequences[:, batch['input_ids'].shape[1]:]
return output, rng_generator()
_sharded_forward_generate = pjit(
_forward_generate,
in_shardings=(model_ps, PS(), PS()),
out_shardings=(PS(), PS()),
static_argnums=(3, 4, 5, 6)
)
# Generate an image or first frame (for video)
def generate_first_frame(prompts, max_input_length):
nonlocal sharded_rng
uncond_prompts = [""] * len(prompts)
prompts = prompts + uncond_prompts
inputs = prefix_tokenizer(
prompts,
padding='max_length',
truncation=True,
max_length=max_input_length,
return_tensors='np'
)
batch = dict(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
vision_masks=np.zeros(inputs.input_ids.shape, dtype=bool),
)
with mesh:
output, sharded_rng = _sharded_forward_generate(
params, sharded_rng, batch,
tokens_per_frame, FLAGS.cfg_scale_image,
FLAGS.top_k_image, FLAGS.temperature_image
)
output = jax.device_get(output)
output = np.split(output, 2, axis=0)[0]
output = output.reshape(len(prompts) // 2, tokens_per_frame)
image = vqgan.decode(output[:, :-1].reshape(-1, 16, 16))
image = ((jax.device_get(image) + 1) * 127.5).astype(np.uint8)
return output, image
sharded_rng = next_rng()
prompts = [FLAGS.prompt]
entries = []
for prompt in prompts:
entries.append({
'caption': prompt,
'prompt': f"You are a helpful assistant. USER: Generate an image of {prompt} ASSISTANT: <vision>",
})
B = 1
images, image_encodings = [], []
for i in tqdm(list(range(0, len(entries), B))):
entries_i = entries[i:i + B]
prompts = [entry['prompt'] for entry in entries_i]
img_enc, img = generate_first_frame(prompts, max_input_length=128)
image_encodings.extend(img_enc)
images.extend(img)
if FLAGS.n_frames == 1:
image = images[0]
Image.fromarray(image).save(FLAGS.output_file)
return
# Generate the rest of the video
def generate_video_pred(prompts, images, max_input_length):
nonlocal sharded_rng
images = np.concatenate([images, images], axis=0)
uncond_prompts = [""] * len(prompts)
prompts = prompts + uncond_prompts
inputs = prefix_tokenizer(
prompts,
padding='max_length',
truncation=True,
max_length=max_input_length,
return_tensors='np'
)
batch = dict(
input_ids=np.concatenate([inputs.input_ids, images], axis=1),
attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype)], axis=1),
vision_masks=np.concatenate([
np.zeros(inputs.input_ids.shape, dtype=bool),
np.ones(images.shape, dtype=bool)
], axis=1),
)
with mesh:
output, sharded_rng = _sharded_forward_generate(
params, sharded_rng, batch,
(FLAGS.n_frames - 1) * tokens_per_frame, FLAGS.cfg_scale_video,
FLAGS.top_k_video, FLAGS.temperature_video
)
output = jax.device_get(output)
output = np.split(output, 2, axis=0)[0]
output = output.reshape(len(prompts) // 2, FLAGS.n_frames - 1, tokens_per_frame)
output = np.concatenate([images[:len(prompts) // 2, None], output], axis=1)
output = output[:, :, :-1].reshape(-1, FLAGS.n_frames, 16, 16)
vision = []
for v in output:
v = vqgan.decode(v)
v = ((jax.device_get(v) + 1) * 127.5).astype(np.uint8)
vision.append(v)
return vision
new_entries = []
for img_enc, entry in zip(image_encodings, entries):
new_entries.append({
'caption': entry['caption'],
'prompt': f"You are a helpful assistant. USER: Generate a video of {entry['caption']} ASSISTANT: <vision>",
'image': np.array(img_enc, dtype=np.int32),
})
entries = new_entries
B = 1
videos = []
for i in tqdm(list(range(0, len(entries), B))):
entries_i = entries[i:i + B]
prompts = [entry['prompt'] for entry in entries_i]
images = np.array([entry['image'] for entry in entries_i], dtype=np.int32)
videos.extend(generate_video_pred(prompts, images, max_input_length=128))
video = videos[0]
writer = imageio.get_writer(FLAGS.output_file, fps=4)
for frame in video:
writer.append_data(frame)
writer.close()
print('done')
if __name__ == "__main__":
run(main)
================================================
FILE: lwm/vision_llama.py
================================================
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import warnings
import copy
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as PS
import flax.linen as nn
from flax.core.frozen_dict import unfreeze, freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from transformers.generation.flax_utils import SampleState, FlaxLogitsProcessorList, FlaxSampleOutput, logger
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers import GenerationConfig
from tux import load_pickle, open_file
from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm
VIDEO_LLAMA_STANDARD_CONFIGS = LLAMA_STANDARD_CONFIGS
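# Illustrative sketch (not part of the original module): resolving a `mesh_dim`
# string such as '1,-1,1,1' against the number of available devices. The sketch
# assumes that the four entries map, in order, to the mesh axes ('dp', 'fsdp',
# 'tp', 'sp') used by the partition rules and sharding constraints. It also
# assumes that a single -1 absorbs the remaining devices. The function name is
# hypothetical.
def _sketch_resolve_mesh_dim(mesh_dim, n_devices, axis_names=('dp', 'fsdp', 'tp', 'sp')):
    dims = [int(x) for x in mesh_dim.split(',')]
    if len(dims) != len(axis_names):
        raise ValueError(f"Expected {len(axis_names)} mesh dims, got {len(dims)}")
    if dims.count(-1) > 1:
        raise ValueError("At most one mesh dim may be -1")
    known = 1
    for d in dims:
        if d != -1:
            known *= d
    if -1 in dims:
        if n_devices % known != 0:
            raise ValueError(f"{n_devices} devices are not divisible by {known}")
        dims[dims.index(-1)] = n_devices // known
    elif known != n_devices:
        raise ValueError(f"Mesh size {known} does not match device count {n_devices}")
    return dict(zip(axis_names, dims))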
class VideoLLaMAConfig(LLaMAConfig):
model_type = "video_llama"
def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_mode='all', **kwargs):
super().__init__(**kwargs)
self.vision_vocab_size = vision_vocab_size # 8192 + 256
self.tie_vision_embeddings = tie_vision_embeddings
self.sample_mode = sample_mode
@staticmethod
def get_partition_rules(scan_layers=False, scan_axis=0):
"""Partition rules are ordered, so earlier rules take precedence when a parameter matches more than one."""
if scan_layers:
if scan_axis == 0:
return (
# embeddings
("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
# attention
("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")),
("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))),
# mlp
("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")),
("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))),
("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")),
# layer norms
("attention_norm/kernel", PS(None, None)),
("ffn_norm/kernel", PS(None, None)),
# output head
("transformer/ln_f/kernel", PS(None)),
("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
('.*', PS(None)),
)
elif scan_axis == 1:
return (
# embeddings
("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
# attention
("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")),
("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))),
# mlp
("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")),
("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))),
("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")),
# layer norms
("attention_norm/kernel", PS(None, None)),
("ffn_norm/kernel", PS(None, None)),
# output head
("transformer/ln_f/kernel", PS(None)),
("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
('.*', PS(None)),
)
else:
raise ValueError(f"Invalid scan_axis {scan_axis}")
else:
return (
# embeddings
("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))),
("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))),
# attention
("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")),
("attention/wo/kernel", PS("tp", ("fsdp", "sp"))),
# mlp
("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")),
("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))),
("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")),
# layer norms
("attention_norm/kernel", PS(None)),
("ffn_norm/kernel", PS(None)),
# output head
("transformer/ln_f/kernel", PS(None)),
("lm_head/kernel", PS(("fsdp", "sp"), "tp")),
("vision_head/kernel", PS(("fsdp", "sp"), "tp")),
('.*', PS(None)),
)
@classmethod
def load_config(cls, path):
if path in VIDEO_LLAMA_STANDARD_CONFIGS:
return cls.from_dict(VIDEO_LLAMA_STANDARD_CONFIGS[path])
load_type, load_path = path.split('::', 1)
if load_type == 'pickle':
return cls.from_dict(load_pickle(load_path)['llama_config'])
elif load_type == 'json':
with open_file(load_path, 'r') as fin:
raw_config = fin.read()
return cls.from_dict(json.loads(raw_config))
else:
raise ValueError(f'Unsupported load config type: {load_type}')
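# Illustrative sketch (not part of the original module): why the rule order in
# `get_partition_rules` matters. Each parameter path (joined with '/') is checked
# against the rules in order and takes the spec of the first pattern found with
# `re.search`. The trailing ('.*', PS(None)) acts as the catch-all. This is a
# simplified stand-in for `tux.match_partition_rules` that uses plain tuples in
# place of `PartitionSpec`. The function name is hypothetical.
import re


def _sketch_first_matching_spec(rules, param_path):
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"No partition rule matches {param_path}")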
class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = VideoLLaMAConfig
base_model_prefix = "transformer"
module_class: nn.Module = None
def __init__(
self,
config: VideoLLaMAConfig,
input_shape: Tuple = (4, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_cache(self, batch_size, max_length):
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
segment_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
vision_masks = jnp.ones((batch_size, max_length), dtype=bool)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True
)
return init_variables["cache"]
def init_weights(self, rng, input_shape, params=None):
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
vision_masks = jnp.ones(input_ids.shape, dtype=bool)
segment_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward("")
def __call__(
self,
input_ids,
vision_masks,
attention_mask=None,
segment_ids=None,
position_ids=None,
params: dict = None,
past_key_values: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
batch_size, sequence_length = input_ids.shape
if position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
if segment_ids is None:
segment_ids = jnp.zeros((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(vision_masks, dtype="f4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(segment_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
False,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
return outputs
class FlaxVideoLLaMAModule(nn.Module):
config: VideoLLaMAConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self):
self.embed_dim = self.config.hidden_size
self.vte = nn.Embed(
self.config.vision_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.wte = nn.Embed(
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
param_dtype=self.param_dtype,
)
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
def __call__(
self,
input_ids,
vision_masks,
attention_mask,
segment_ids,
position_ids,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
input_ids = input_ids.astype("i4")
if input_ids.shape[1] == 1:
if self.config.sample_mode == 'text':
input_embeds = self.wte(input_ids)
elif self.config.sample_mode == 'vision':
input_embeds = self.vte(input_ids)
elif self.config.sample_mode == 'all':
raise NotImplementedError
else:
raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}")
else:
input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))
input_vision_embeds = self.vte(jnp.where(vision_masks, input_ids, 0))
vision_masks = vision_masks[..., None].astype("f4") # 1 is vision, 0 is text
input_embeds = input_text_embeds * (1 - vision_masks) + input_vision_embeds * vision_masks
hidden_states = self.dropout(input_embeds, deterministic=deterministic)
outputs = self.h(
hidden_states,
attention_mask,
segment_ids,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[-1],
)
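For prompts longer than one token, `FlaxVideoLLaMAModule` looks up every position in both `wte` and `vte` and keeps one of the two results according to `vision_masks` (1 = vision, 0 = text). The NumPy sketch below reproduces that routing with tiny, made-up vocabularies. `text_table`, `vision_table` and `mixed_embed` are illustrative stand-ins for the two `nn.Embed` tables and the module logic, and `jax.numpy` accepts the same calls.

```python
import numpy as np

def mixed_embed(input_ids, vision_masks, text_table, vision_table):
    """Blend text and vision embeddings per position (mask 1 = vision, 0 = text)."""
    # Out-of-modality positions are redirected to id 0 so each lookup stays in range.
    text_embeds = text_table[np.where(vision_masks, 0, input_ids)]
    vision_embeds = vision_table[np.where(vision_masks, input_ids, 0)]
    mask = vision_masks[..., None].astype(np.float32)
    return text_embeds * (1.0 - mask) + vision_embeds * mask

rng = np.random.default_rng(0)
text_table = rng.normal(size=(10, 4)).astype(np.float32)   # hypothetical 10-token text vocab
vision_table = rng.normal(size=(6, 4)).astype(np.float32)  # hypothetical 6-token vision vocab
input_ids = np.array([[3, 5, 1, 7]])
vision_masks = np.array([[False, True, True, False]])
embeds = mixed_embed(input_ids, vision_masks, text_table, vision_table)
print(embeds.shape)  # (1, 4, 4)
```

Redirecting the other modality's ids to 0 only keeps the lookup in range; that row is then discarded by the blend, so the value of id 0 never leaks into the output.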
class FlaxVideoLLaMAForCausalLMModule(nn.Module):
config: VideoLLaMAConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype=jnp.float32
precision: Optional[Union[jax.lax.Precision, str]]=None
def setup(self):
self.transformer = FlaxVideoLLaMAModule(self.config, dtype=self.dtype)
self.vision_head = nn.Dense(
self.config.vision_vocab_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
precision=self.precision,
)
self.lm_head = nn.Dense(
self.config.vocab_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
precision=self.precision,
)
def __call__(
self,
input_ids,
vision_masks,
attention_mask=None,
segment_ids=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
batch_size, seq_length = input_ids.shape
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if segment_ids is None:
segment_ids = jnp.zeros_like(input_ids)
if position_ids is None:
position_ids = jnp.broadcast_to(
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
(batch_size, seq_length)
)
outputs = self.transformer(
input_ids,
vision_masks,
attention_mask,
segment_ids,
position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_vision_embeddings:
shared_kernel = self.transformer.variables["params"]["vte"]["embedding"].T
vision_logits = self.vision_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
else:
vision_logits = self.vision_head(hidden_states)
if self.config.tie_word_embeddings:
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
if self.config.sample_mode == 'all':
if not return_dict:
return (vision_logits, lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=(vision_logits, lm_logits), hidden_states=outputs.hidden_states, attentions=outputs.attentions)
elif self.config.sample_mode == 'vision':
if not return_dict:
return (vision_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=vision_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
elif self.config.sample_mode == 'text':
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
else:
raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}")
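When `tie_word_embeddings` or `tie_vision_embeddings` is set, the corresponding head reuses the embedding matrix, transposed, as its `Dense` kernel, so its logits reduce to `hidden_states @ embedding.T`. A minimal NumPy sketch of that projection, with hypothetical sizes (`tied_logits` is an illustrative name):

```python
import numpy as np

def tied_logits(hidden_states, embedding):
    """Project hidden states to vocabulary logits with a tied embedding matrix."""
    # embedding is (vocab_size, hidden_size); a Dense kernel is (hidden_size, vocab_size)
    kernel = embedding.T
    return hidden_states @ kernel

rng = np.random.default_rng(0)
embedding = rng.normal(size=(8, 4))  # hypothetical vocab_size=8, hidden_size=4
hidden = rng.normal(size=(2, 3, 4))  # (batch, seq_len, hidden_size)
logits = tied_logits(hidden, embedding)
print(logits.shape)  # (2, 3, 8)
```

Each logit is the dot product between a hidden state and one token's embedding row, which is why no separate head parameters are needed in the tied case.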
@add_start_docstrings("", "")
class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel):
module_class = FlaxVideoLLaMAForCausalLMModule
def prepare_inputs_for_generation(
self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, vision_masks = None
):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
"vision_masks": vision_masks
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
return {
"past_key_values": model_outputs.past_key_values,
"position_ids": model_kwargs["position_ids"][:, -1:] + 1,
"attention_mask": model_kwargs["attention_mask"],
"vision_masks": model_kwargs["vision_masks"]
}
def _sample_vision(
self,
input_ids: jnp.ndarray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
prng_key: Optional[jnp.ndarray] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
logits_warper: Optional[FlaxLogitsProcessorList] = None,
cfg_scales: jnp.ndarray = 1.0,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
# init values
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
batch_size, cur_len = input_ids.shape
initial_len = cur_len
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
cur_len = jnp.array(cur_len)
# per batch-item holding current token in loop.
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
model = self.decode if self.config.is_encoder_decoder else self
# initialize model specific kwargs
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = SampleState(
cur_len=cur_len,
sequences=sequences,
running_token=input_ids,
is_sent_finished=is_sent_finished,
prng_key=prng_key,
model_kwargs=model_kwargs,
)
def sample_search_cond_fn(state):
"""state termination condition fn."""
has_reached_max_length = state.cur_len == max_length
all_sequence_finished = jnp.all(state.is_sent_finished)
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
return ~finish_generation
def sample_search_body_fn(state):
"""state update fn."""
prng_key, prng_key_next = jax.random.split(state.prng_key)
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
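logits = model_outputs.logits[:, -1]
# classifier-free guidance: first half of the batch is conditioned, second half unconditioned
cond_logits, uncond_logits = jnp.split(logits, 2, axis=0)
# reshaping lets cfg_scales be a scalar (the 1.0 default) or a per-example array
logits = uncond_logits + jnp.reshape(cfg_scales, (-1, 1)) * (cond_logits - uncond_logits)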
# apply min_length, ...
logits = logits_processor(state.sequences, logits, state.cur_len)
# apply top_p, top_k, temperature
logits = logits_warper(state.sequences, logits, state.cur_len)
next_token = jax.random.categorical(prng_key, logits, axis=-1)
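# force id 8192 (one past the 8192-entry VQGAN codebook) at every 257th generated position,
# presumably the delimiter closing each 256-token (16x16) frame
next_token = jax.lax.cond(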
(state.cur_len - initial_len + 1) % 257 == 0,
lambda: jnp.full_like(next_token, 8192),
lambda: next_token
)
next_token = jnp.concatenate([next_token, next_token], axis=0)
#next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
return SampleState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
running_token=next_token,
is_sent_finished=next_is_sent_finished,
model_kwargs=next_model_kwargs,
prng_key=prng_key_next,
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
if input_ids.shape[1] > 1:
state = sample_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
else:
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
return FlaxSampleOutput(sequences=state.sequences)
def generate_vision(
self,
input_ids: jnp.ndarray,
cfg_scales: jnp.ndarray,
generation_config: Optional[GenerationConfig] = None,
prng_key: Optional[jnp.ndarray] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
**kwargs,
):
# Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
self.generation_config
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()
# set init values
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask") is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
# decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
if not self.config.is_encoder_decoder and not trace:
if (
generation_config.pad_token_id is not None
and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
batch_size = input_ids.shape[0]
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
if model_kwargs.get("encoder_outputs") is None:
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
)
# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
"to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`."
)
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
logits_processor=logits_processor,
)
if not generation_config.do_sample and generation_config.num_beams == 1:
raise NotImplementedError("Greedy search is not implemented for vision generation; set `do_sample=True`.")
elif generation_config.do_sample and generation_config.num_beams == 1:
logits_warper = self._get_logits_warper(generation_config=generation_config)
return self._sample_vision(
input_ids,
generation_config.max_length,
generation_config.pad_token_id,
generation_config.eos_token_id,
prng_key,
logits_warper=logits_warper,
logits_processor=logits_processor,
cfg_scales=cfg_scales,
trace=trace,
params=params,
model_kwargs=model_kwargs,
)
elif not generation_config.do_sample and generation_config.num_beams > 1:
raise NotImplementedError("Beam search is not implemented for vision generation.")
else:
raise NotImplementedError("Beam sampling is currently not implemented.")
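`_sample_vision` runs conditional and unconditional prompts in a single batch. The first half holds conditioned rows and the second half unconditioned rows; their last-step logits are combined as `uncond + cfg_scales * (cond - uncond)`, and the sampled token is duplicated so both halves stay in step. The loop also overrides every 257th generated token with id 8192. The sketch below mirrors one such step in NumPy. It swaps `jax.random.categorical` for a greedy `argmax` so the output is deterministic and omits the logits processors and warpers. `cfg_step` and its constants are hypothetical names. Reading 256 as the 16x16 latent grid of the default VQGAN config, and 8192 as a frame delimiter just past its 8192-entry codebook, is an inference from those defaults, not something this file states.

```python
import numpy as np

TOKENS_PER_FRAME = 256  # assumed: 16x16 latent grid of the default VQGAN config
DELIMITER_ID = 8192     # assumed: first id past the 8192-entry VQGAN codebook

def cfg_step(logits, cfg_scales, step):
    """One guided decoding step. `logits` is (2 * B, vocab); `step` counts generated tokens from 0."""
    cond, uncond = np.split(logits, 2, axis=0)
    # reshape lets cfg_scales be a scalar or a per-example (B,) array
    guided = uncond + np.reshape(cfg_scales, (-1, 1)) * (cond - uncond)
    next_token = np.argmax(guided, axis=-1)  # greedy stand-in for jax.random.categorical
    if (step + 1) % (TOKENS_PER_FRAME + 1) == 0:
        next_token = np.full_like(next_token, DELIMITER_ID)
    # both halves of the batch continue with the same token
    return np.concatenate([next_token, next_token], axis=0)

logits = np.array([[0.0, 1.0, 0.5],   # conditioned row
                   [2.0, 0.0, 0.0]])  # unconditioned row
print(cfg_step(logits, 1.0, step=0))    # [1 1]
print(cfg_step(logits, 0.0, step=0))    # [0 0]
print(cfg_step(logits, 1.0, step=256))  # [8192 8192]
```

With a scale of 1.0 the guided logits equal the conditional ones. A scale of 0.0 falls back to the unconditional branch, and values above 1.0 push the distribution further away from it.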
================================================
FILE: lwm/vqgan.py
================================================
from typing import Optional
from functools import cached_property, partial
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import jax_utils
from transformers.configuration_utils import PretrainedConfig
from ml_collections import ConfigDict
from tux import function_args_to_config, open_file
class VQGAN:
def __init__(self, vqgan_checkpoint, replicate=False):
assert vqgan_checkpoint != '', 'vqgan_checkpoint must be a non-empty path'
self.replicate = replicate
self.config = VQGANConfig.get_default_config()
self.params = pickle.load(open_file(vqgan_checkpoint, 'rb'))
if replicate:
self.params = jax_utils.replicate(self.params)
else:
self.params = jax.jit(lambda x: x)(self.params)
self.model = VQGANModel(self.config)
def _wrap_fn(self, fn):
if self.replicate:
return jax.pmap(fn, devices=jax.local_devices())
else:
return jax.jit(fn)
@cached_property
def _encode(self):
def fn(pixel_values, params):
return self.model.apply(
{'params': params},
pixel_values,
method=self.model.encode
)
return partial(self._wrap_fn(fn), params=self.params)
@cached_property
def _decode(self):
def fn(encoding, params):
return self.model.apply(
{'params': params},
encoding,
method=self.model.decode
)
return partial(self._wrap_fn(fn), params=self.params)
def encode(self, pixel_values):
return self._encode(pixel_values)
def decode(self, encoding):
return self._decode(encoding)
class VQGANConfig(PretrainedConfig):
model_type = "vqgan"
def __init__(
self,
resolution=256,
num_channels=3,
hidden_channels=128,
channel_mult=(1, 2, 2, 4, 6),
num_res_blocks=2,
attn_resolutions=(),
no_attn_mid_block=True,
z_channels=64,
num_embeddings=8192,
quantized_embed_dim=64,
dropout=0.0,
resample_with_conv=True,
commitment_cost=0.25
):
self.resolution = resolution
self.num_channels = num_channels
self.hidden_channels = hidden_channels
self.channel_mult = channel_mult
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.no_attn_mid_block = no_attn_mid_block
self.z_channels = z_channels
self.num_embeddings = num_embeddings
self.quantized_embed_dim = quantized_embed_dim
self.dropout = dropout
self.resample_with_conv = resample_with_conv
self.commitment_cost = commitment_cost
@classmethod
def get_default_config(cls, updates=None):
config = function_args_to_config(cls.__init__)
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
config.num_resolutions = len(config.channel_mult)
return config
@classmethod
def load_config(cls, path):
# `path` is ignored; the default VQGAN configuration is always returned
return cls.get_default_config()
class VQGANModel(nn.Module):
config: VQGANConfig
def setup(self):
self.encoder = Encoder(self.config)
self.decoder = Decoder(self.config)
self.quantize = VectorQuantizer(
self.config.num_embeddings, self.config.quantized_embed_dim
)
self.quant_conv = nn.Conv(self.config.quantized_embed_dim, [1, 1])
self.post_quant_conv = nn.Conv(self.config.z_channels, [1, 1])
def encode(self, pixel_values):
T = None
if len(pixel_values.shape) == 5: # video
T = pixel_values.shape[1]
pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quantized_states, codebook_indices = self.quantize(hidden_states)
if T is not None:
quantized_states = quantized_states.reshape(-1, T, *quantized_states.shape[1:])
codebook_indices = codebook_indices.reshape(-1, T, *codebook_indices.shape[1:])
return quantized_states, codebook_indices
def decode(self, encoding, is_codebook_indices=True):
if is_codebook_indices:
encoding = self.quantize(None, encoding)
T = None
if len(encoding.shape) == 5:
T = encoding.shape[1]
encoding = encoding.reshape(-1, *encoding.shape[2:])
hidden_states = self.post_quant_conv(encoding)
reconstructed_pixel_values = self.decoder(hidden_states)
if T is not None:
reconstructed_pixel_values = reconstructed_pixel_values.reshape(-1, T, *reconstructed_pixel_values.shape[1:])
return jnp.clip(reconstructed_pixel_values, -1, 1)
def __call__(self, pixel_values):
encoding = self.encode(pixel_values)[1]
recon = self.decode(encoding)
return recon
class Encoder(nn.Module):
config: VQGANConfig
@nn.compact
def __call__(self, pixel_values):
assert pixel_values.shape[1] == pixel_values.shape[2] == self.config.resolution, pixel_values.shape
hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)
for i_level in range(self.config.num_resolutions):
hidden_states = DownsamplingBlock(self.config, i_level)(hidden_states)
hidden_states = MidBlock(
self.config, self.config.no_attn_mid_block, self.config.dropout
)(hidden_states)
hidden_states = nn.GroupNorm()(hidden_states)
hidden_states = nn.silu(hidden_states)
hidden_states = nn.Conv(self.config.z_channels, [3, 3])(hidden_states)
return hidden_states
class Decoder(nn.Module):
config: VQGANConfig
@nn.compact
def __call__(self, hidden_states):
hidden_states = nn.Conv(
self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1],
[3, 3]
)(hidden_states)
hidden_states = MidBlock(
self.config, self.config.no_attn_mid_block, self.config.dropout
)(hidden_states)
for i_level in reversed(range(self.config.num_resolutions)):
hidden_states = UpsamplingBlock(self.config, i_level)(hidden_states)
hidden_states = nn.GroupNorm()(hidden_states)
hidden_states = nn.silu(hidden_states)
hidden_states = nn.Conv(self.config.num_channels, [3, 3])(hidden_states)
return hidden_states
class VectorQuantizer(nn.Module):
n_e: int
e_dim: int
@nn.compact
def __call__(self, z, encoding_indices=None):
def quantize(encoding_indices):
w = jax.device_put(embeddings)
return w[(encoding_indices,)]
embeddings = self.param(
'embeddings',
lambda rng, shape, dtype: jax.random.uniform(
rng, shape, dtype, minval=-1.0 / self.n_e, maxval=1.0 / self.n_e
),
[self.n_e, self.e_dim], jnp.float32
)
if encoding_indices is not None:
return quantize(encoding_indices)
z_flattened = z.reshape(-1, z.shape[-1])
d = jnp.sum(z_flattened ** 2, axis=1, keepdims=True) + \
jnp.sum(embeddings.T ** 2, axis=0, keepdims=True) - \
2 * jnp.einsum('bd,nd->bn', z_flattened, embeddings)
min_encoding_indices = jnp.argmin(d, axis=1)
z_q = quantize(min_encoding_indices)
z_q = jnp.reshape(z_q, z.shape)
z_q = z + jax.lax.stop_gradient(z_q - z)
encodings_one_hot = jax.nn.one_hot(min_encoding_indices, num_classes=self.n_e)
assert len(encodings_one_hot.shape) == 2
min_encoding_indices = jnp.reshape(min_encoding_indices, z.shape[:-1])
return z_q, min_encoding_indices
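`VectorQuantizer` finds the nearest codebook entry for every latent vector using the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e. This turns the pairwise distance computation into one matrix product. The NumPy sketch below reproduces the index and vector lookup with a hypothetical codebook size; `nearest_codes` is an illustrative name. The straight-through step `z + stop_gradient(z_q - z)` only matters under `jax.grad`, so it is left out.

```python
import numpy as np

def nearest_codes(z, embeddings):
    """Nearest-codebook indices and vectors for latents z of shape (..., e_dim)."""
    z_flat = z.reshape(-1, z.shape[-1])
    # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e for every (latent, code) pair at once
    d = (np.sum(z_flat ** 2, axis=1, keepdims=True)
         + np.sum(embeddings ** 2, axis=1)[None, :]
         - 2.0 * z_flat @ embeddings.T)
    indices = np.argmin(d, axis=1)
    z_q = embeddings[indices].reshape(z.shape)
    return indices.reshape(z.shape[:-1]), z_q

rng = np.random.default_rng(0)
codebook = rng.normal(size=(16, 4))  # hypothetical n_e=16, e_dim=4
z = rng.normal(size=(2, 3, 3, 4))    # (batch, height, width, e_dim)
idx, z_q = nearest_codes(z, codebook)
print(idx.shape, z_q.shape)  # (2, 3, 3) (2, 3, 3, 4)
```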
class DownsamplingBlock(nn.Module):
config: VQGANConfig
block_idx: int
@nn.compact
def __call__(self, hidden_states):
block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
for _ in range(self.config.num_res_blocks):
hidden_states = ResnetBlock(
block_out, dropout_prob=self.config.dropout
)(hidden_states)
if hidden_states.shape[1] in self.config.attn_resolutions:
hidden_states = AttnBlock()(hidden_states)
if self.block_idx != self.config.num_resolutions - 1:
hidden_states = Downsample(self.config.resample_with_conv)(hidden_states)
return hidden_states
class ResnetBlock(nn.Module):
out_channels: Optional[int] = None
use_conv_shortcut: bool = False
dropout_prob: float = 0.0
@nn.compact
def __call__(self, hidden_states):
out_channels = self.out_channels or hidden_states.shape[-1]
residual = hidden_states
hidden_states = nn.GroupNorm()(hidden_states)
hidden_states = nn.silu(hidden_states)
hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)
hidden_states = nn.GroupNorm()(hidden_states)
hidden_states = nn.silu(hidden_states)
hidden_states = nn.Dropout(self.dropout_prob, deterministic=True)(hidden_states)
hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)
if out_channels != residual.shape[-1]:
if self.use_conv_shortcut:
residual = nn.Conv(out_channels, [3, 3])(residual)
else:
residual = nn.Conv(out_channels, [1, 1])(residual)
return hidden_states + residual
class AttnBlock(nn.Module):
@nn.compact
def __call__(self, hidden_states):
residual = hidden_states
hidden_states = nn.GroupNorm()(hidden_states)
query = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
key = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
value = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)
query, key, value = map(
lambda x: x.reshape(x.shape[0], -1, x.shape[-1]),
[query, key, value]
)
attn_weights = jnp.einsum("bqd,bkd->bqk", query, key)
attn_weights *= hidden_states.shape[-1] ** -0.5
attn_weights = jax.nn.softmax(attn_weights, axis=-1)
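hidden_states = jnp.einsum("bqk,bkd->bqd", attn_weights, value)
# restore the (batch, height, width, channels) layout so the residual add lines up
hidden_states = hidden_states.reshape(residual.shape)
hidden_states = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)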
return hidden_states + residual
class Downsample(nn.Module):
with_conv: bool
@nn.compact
def __call__(self, hidden_states):
if self.with_conv:
hidden_states = jnp.pad(
hidden_states,
[(0, 0), (0, 1), (0, 1), (0, 0)]
)
hidden_states = nn.Conv(
hidden_states.shape[-1], [3, 3],
strides=[2, 2],
padding="VALID"
)(hidden_states)
else:
hidden_states = nn.avg_pool(hidden_states, [2, 2], [2, 2])
return hidden_states
class Upsample(nn.Module):
with_conv: bool
@nn.compact
def __call__(self, hidden_states):
B, H, W, C = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
(B, H * 2, W * 2, C),
method="nearest"
)
if self.with_conv:
hidden_states = nn.Conv(hidden_states.shape[-1], [3, 3])(hidden_states)
return hidden_states
class UpsamplingBlock(nn.Module):
config: VQGANConfig
block_idx: int
@nn.compact
def __call__(self, hidden_states):
block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
for _ in range(self.config.num_res_blocks + 1):
hidden_states = ResnetBlock(
block_out, dropout_prob=self.config.dropout
)(hidden_states)
if hidden_states.shape[1] in self.config.attn_resolutions:
hidden_states = AttnBlock()(hidden_states)
if self.block_idx != 0:
hidden_states = Upsample(self.config.resample_with_conv)(hidden_states)
return hidden_states
class MidBlock(nn.Module):
config: VQGANConfig
no_attn: bool
dropout: float
@nn.compact
def __call__(self, hidden_states):
hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)
if not self.no_attn:
hidden_states = AttnBlock()(hidden_states)
hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)
return hidden_states
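Every `DownsamplingBlock` except the last applies a stride-2 `Downsample`: it pads one row and one column, then applies a 3x3 `VALID` convolution. With the default `channel_mult` of length 5, the side length therefore halves four times, which fixes how many tokens a single 256x256 frame becomes. A small sketch of that arithmetic follows. `latent_grid` is a hypothetical helper that assumes `resample_with_conv=True`, although the `avg_pool` path halves even sizes the same way.

```python
def latent_grid(resolution=256, channel_mult=(1, 2, 2, 4, 6)):
    """Side length of the VQGAN latent grid for a square input (defaults from VQGANConfig)."""
    size = resolution
    # one stride-2 Downsample per level except the last (block_idx != num_resolutions - 1)
    for _ in range(len(channel_mult) - 1):
        # pad one row/column, then a 3x3 VALID convolution with stride 2
        size = (size + 1 - 3) // 2 + 1
    return size

side = latent_grid()
print(side, side * side)  # 16 256
```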
================================================
FILE: scripts/create_needle_data.py
================================================
import os
import argparse
import json
from tqdm import tqdm
from datasets import load_dataset
parser = argparse.ArgumentParser()
parser.add_argument("--output_path", type=str, default="data/pg19.jsonl")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
dset = load_dataset("pg19")["train"]
with open(args.output_path, "w") as f:
for elem in tqdm(dset):
data = {"text": elem["text"]}
f.write(f"{json.dumps(data)}\n")
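The script writes one JSON object per line, each holding a single PG-19 book under the `"text"` key. Because `json.dumps` escapes embedded newlines, every record stays on one line. The sketch below writes and reads a tiny file in the same layout. `write_jsonl` and `read_jsonl` are hypothetical helpers, and how `eval_needle.py` actually loads `haystack_file` is not shown in this section.

```python
import json
import os
import tempfile

def write_jsonl(path, texts):
    """Write one {"text": ...} object per line, like create_needle_data.py."""
    with open(path, "w") as f:
        for text in texts:
            f.write(f"{json.dumps({'text': text})}\n")

def read_jsonl(path):
    """Read the "text" field back from every non-empty line."""
    with open(path) as f:
        return [json.loads(line)["text"] for line in f if line.strip()]

path = os.path.join(tempfile.mkdtemp(), "pg19_sample.jsonl")
write_jsonl(path, ["First book.", "Second book,\nwith a line break."])
texts = read_jsonl(path)
```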
================================================
FILE: scripts/eval_needle.py
================================================
from absl.app import run
import time
import json
import math
import os
from tqdm import tqdm
import random
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import gcsfs
import tiktoken
from transformers import GenerationConfig, AutoTokenizer
from tux import (
define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
match_partition_rules, make_shard_and_gather_fns,
with_sharding_constraint, tree_apply, open_file
)
from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM
FLAGS, FLAGS_DEF = define_flags_with_default(
haystack_file="",
max_tokens_per_batch=2000000,
output_file="results.json",
context_lengths_min=1000,
context_lengths_max=32000,
n_context_length_intervals=3,
n_document_depth_intervals=3,
n_rounds=2,
seed=1234,
mesh_dim='1,-1,1,1',
dtype='fp32',
load_llama_config='',
update_llama_config='',
load_checkpoint='',
tokenizer='LargeWorldModel/LWM-Text-1M',
checkpointer=StreamingCheckpointer.get_default_config(),
llama=LLaMAConfig.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
class LLMNeedleHaystackTester:
OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document or repeat your findings. Keep your response short and direct. ASSISTANT: "
RANDOM_NEEDLE_CITIES = [
'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
]
def __init__(self,
needle="",
haystack_file="",
retrieval_question="What is the special magic {} number?",
results_version = 1,
rnd_number_digits = 7,
context_lengths_min = 1000,
context_lengths_max = 126000,
context_lengths_num_intervals = 10,
document_depth_percent_min = 0,
document_depth_percent_max = 100,
document_depth_percent_intervals = 10,
document_depth_percent_interval_type = "linear",
save_results = False,
final_context_length_buffer = 200,
print_ongoing_status = True):
needle = needle or "\nThe special magic {city} number is: {rnd_number}\n"
self.needle = needle
if not needle or not haystack_file or not retrieval_question:
raise ValueError("Needle, haystack, and retrieval_question must be provided.")
self.rnd_number_digits = rnd_number_digits
self.context_lengths_num_intervals = context_lengths_num_intervals
self.document_depth_percent_intervals = document_depth_percent_intervals
self.haystack_file = haystack_file
self.retrieval_question = retrieval_question
self.results_version = results_version
self.save_results = save_results
self.final_context_length_buffer = final_context_length_buffer
self.print_ongoing_status = print_ongoing_status
self.testing_results = []
self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
if document_depth_percent_interval_type == 'linear':
self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
elif document_depth_percent_interval_type == 'sigmoid':
self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
else:
raise ValueError(f"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}")
self.model = Sampler()
self.enc = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
self.enc_tiktoken = tiktoken.encoding_for_model("gpt-4-1106-preview")
def generate_random_number(self, num_digits):
lower_bound = 10**(num_digits - 1)
upper_bound = 10**num_digits - 1
return random.randint(lower_bound, upper_bound)
def logistic(self, x, L=100, x0=50, k=.1):
if x == 0:
return 0
if x == 100:
return 100
return np.round(L / (1 + np.exp(-k * (x - x0))), 3)
def read_context_files(self, n):
max_context_length = max(self.context_lengths)
contexts = []
f = open_file(self.haystack_file, 'r')
for _ in range(n):
context = ""
toks = 0
while toks < max_context_length:
text = json.loads(f.readline())['text']
context += text
toks += len(self.enc.encode(text))
contexts.append(context)
return contexts
def encode_and_trim(self, context, context_length):
tokens = self.enc.encode(context)
if len(tokens) > context_length:
context = self.enc.decode(tokens[:context_length])
return context
def create_contexts(self, needle_rnd_number, insert_needle, random_city, trim_context, context_length, depth_percent, seed):
# NOTE: result_exists() is not defined in this script, so save_results must remain False.
if self.save_results:
if self.result_exists(context_length, depth_percent):
return
needle = self.needle.format(city=random_city, rnd_number=needle_rnd_number)
question = self.retrieval_question.format(random_city)
if not insert_needle:
needle = " " #replace needle with a space
context = self.generate_context(needle, trim_context, context_length, depth_percent)
results = {
'context' : context,
'context_length' : int(context_length),
'depth_percent' : float(depth_percent),
'needle' : needle,
'question' : question,
'insert_needle' : insert_needle,
'needle_rnd_number' : needle_rnd_number,
'seed': seed,
}
return results
def insert_needle(self, needle, context, depth_percent, context_length):
tokens_needle = self.enc_tiktoken.encode(needle)
tokens_context = self.enc_tiktoken.encode(context)
# Reduce the context length by final_context_length_buffer (200 tokens by default) to leave room for the system message, the user question, and the response.
context_length -= self.final_context_length_buffer
# If the context plus the needle exceed the context length (the usual case), trim the context by the needle length.
if len(tokens_context) + len(tokens_needle) > context_length:
tokens_context = tokens_context[:context_length - len(tokens_needle)]
if depth_percent == 100:
# If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end
tokens_new_context = tokens_context + tokens_needle
else:
# Go get the position (in terms of tokens) to insert your needle
insertion_point = int(len(tokens_context) * (depth_percent / 100))
# tokens_new_context represents the tokens before the needle
tokens_new_context = tokens_context[:insertion_point]
# We want to make sure that we place our needle at a sentence break so we first see what token a '.' is
period_tokens = self.enc_tiktoken.encode('.')
# Then iterate backwards until the prefix ends with a period
while tokens_new_context and tokens_new_context[-1] not in period_tokens:
insertion_point -= 1
tokens_new_context = tokens_context[:insertion_point]
# Once we get there, then add in your needle, and stick the rest of your context in on the other end.
# Now we have a needle in a haystack
tokens_new_context += tokens_needle + tokens_context[insertion_point:]
# Convert back to a string and return it
new_context = self.enc_tiktoken.decode(tokens_new_context)
return new_context
def generate_context(self, needle, trim_context, context_length, depth_percent):
context = self.insert_needle(needle, trim_context, depth_percent, context_length)
return context
def compute_max_input_length(self, context_length, buffer=1024):
block_size = self.model.block_size
context_length += buffer
context_length = math.ceil(context_length / block_size) * block_size
return int(context_length)
def run_test(self):
fs = gcsfs.GCSFileSystem()
contexts = []
template = self.OURS_TEMPLATE
def _key_from_result(result):
return (result['context_length'], result['depth_percent'], result['seed'])
results = []
completed = set()
def exists(fname):
if fname.startswith('gs://'):
return fs.exists(fname)
else:
return os.path.exists(fname)
if exists(FLAGS.output_file):
with open_file(FLAGS.output_file, 'r') as f:
results = json.load(f)
completed = set([_key_from_result(result) for result in results])
print('completed', len(completed))
full_contexts = self.read_context_files(FLAGS.n_rounds)
full_tokens = [self.enc.encode(full_context) for full_context in tqdm(full_contexts)]
start = time.time()
for context_length in self.context_lengths:
trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in tqdm(full_tokens)]
max_input_length = self.compute_max_input_length(context_length)
contexts = []
for depth_percent in self.document_depth_percents:
for i in range(FLAGS.n_rounds):
if (int(context_length), float(depth_percent), i) in completed:
continue
random_city = random.choice(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES)
insert_needle = True
needle_rnd_number = str(self.generate_random_number(self.rnd_number_digits))
print("context length: " + str(context_length))
print("depth_percent : " + str(depth_percent))
context = self.create_contexts(needle_rnd_number, insert_needle, random_city, trim_contexts[i], context_length, depth_percent, i)
contexts.append(context)
if len(contexts) == 0:
continue
B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
B = int(B / self.model.data_dim) * self.model.data_dim
if B < self.model.data_dim:
B = self.model.data_dim
elif B > len(contexts):
B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
if len(contexts) % B == 0:
n_pad = 0
else:
n_pad = B - len(contexts) % B
for _ in range(n_pad):
contexts.insert(0, contexts[0])
pbar = tqdm(total=len(contexts))
for i in range(0, len(contexts), B):
contexts_i = contexts[i:i + B]
prompts = [
template.format(context=context['context'], question=context['question'])
for context in contexts_i
]
outs = self.model(prompts, max_input_length)
for j, (context, out) in enumerate(zip(contexts_i, outs)):
if i + j < n_pad:
continue
results.append({
'context_length': context['context_length'],
'depth_percent': context['depth_percent'],
'response': out,
'answer': context['needle_rnd_number'],
'correct': context['needle_rnd_number'] in out,
'seed': context['seed'],
})
print(results[-1])
if jax.process_index() == 0:
with open_file(FLAGS.output_file, 'w') as f:
json.dump(results, f)
pbar.update(len(contexts_i))
pbar.close()
print('elapsed', time.time() - start)
print('done')
def print_start_test_summary(self):
print ("\n")
print ("Starting Needle In A Haystack Testing...")
print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
print (f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
print (f"- Needle: {self.needle.strip()}")
print ("\n\n")
def start_test(self):
if self.print_ongoing_status:
self.print_start_test_summary()
self.run_test()
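# Illustrative sketch (hypothetical helper, not called by this script) of the
# placement rule in insert_needle(), written over plain token-id lists so it
# runs without tiktoken. Assumption: `period_ids` holds the id(s) that encode
# '.', which is what enc_tiktoken.encode('.') provides in the real method, and
# `budget` stands for context_length minus final_context_length_buffer.
def sketch_insert_at_sentence_break(context_ids, needle_ids, depth_percent,
                                    budget, period_ids):
    # Trim the haystack so that haystack + needle fits in `budget` tokens.
    if len(context_ids) + len(needle_ids) > budget:
        context_ids = context_ids[:max(0, budget - len(needle_ids))]
    if depth_percent == 100:
        return context_ids + needle_ids
    point = int(len(context_ids) * (depth_percent / 100))
    # Walk backwards until the prefix ends on a period, so the needle starts a
    # new sentence; an empty prefix puts the needle at the very front.
    while point > 0 and context_ids[point - 1] not in period_ids:
        point -= 1
    return context_ids[:point] + needle_ids + context_ids[point:]


# Toy example with 0 standing in for '.':
# sketch_insert_at_sentence_break([1, 2, 0, 3, 4, 0, 5, 6], [9], 50, 100, [0])
# -> [1, 2, 0, 9, 3, 4, 0, 5, 6]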
class Sampler:
def __init__(self):
self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
self.sharded_rng = next_rng()
self._load_model()
@property
def block_size(self):
# return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size)
return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
@property
def data_dim(self):
return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
def _load_model(self):
if FLAGS.load_llama_config != '':
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
updates = LLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
scan_key_chunk_size=updates.scan_key_chunk_size,
scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
scan_layers=updates.scan_layers,
param_scan_axis=updates.param_scan_axis,
))
else:
llama_config = LLaMAConfig(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
))
llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
self.config = llama_config
with jax.default_device(jax.devices("cpu")[0]):
_, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
)
self.model = FlaxLLaMAForCausalLM(
llama_config,
input_shape=(512, self.block_size),
seed=FLAGS.seed,
_do_init=False,
dtype=get_float_dtype_by_name(FLAGS.dtype),
)
self.model_ps = match_partition_rules(
LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
)
shard_fns, _ = make_shard_and_gather_fns(
self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
with self.mesh:
self.params = tree_apply(shard_fns, self.params)
@cached_property
def _forward_generate(self):
def fn(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
rng_generator = JaxRNG(rng)
output = self.model.generate(
batch['input_ids'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=self.block_size,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
temperature=0.,
do_sample=False,
num_beams=1,
top_k=50,
top_p=1.0,
)
).sequences[:, batch['input_ids'].shape[1]:]
return output, rng_generator()
return pjit(
fn,
in_shardings=(self.model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def __call__(self, prompts, max_input_length):
inputs = self.prefix_tokenizer(
prompts,
padding='max_length',
truncation=True,
max_length=max_input_length,
return_tensors='np'
)
batch = dict(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask
)
with self.mesh:
output, self.sharded_rng = self._forward_generate(
self.params, self.sharded_rng, batch
)
output = jax.device_get(output)
output_text = []
for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
if self.tokenizer.eos_token in text:
text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
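# Illustrative sketch (hypothetical helper, not called by this script) of the
# batching arithmetic in run_test(). Inputs are padded up to a multiple of the
# model's block_size (compute_max_input_length), the per-step batch B is derived
# from max_tokens_per_batch and rounded to a multiple of data_dim (dp * fsdp) so
# it shards evenly, and the context list is padded with duplicates to a
# multiple of B.
import math


def sketch_plan_batches(n_contexts, context_length, block_size, data_dim,
                        max_tokens_per_batch, buffer=1024):
    max_input_length = math.ceil((context_length + buffer) / block_size) * block_size
    # Each row costs its prompt plus up to block_size generated tokens
    # (GenerationConfig uses max_new_tokens=block_size).
    b = max_tokens_per_batch / (max_input_length + block_size)
    b = int(b / data_dim) * data_dim
    if b < data_dim:
        b = data_dim
    elif b > n_contexts:
        b = int(math.ceil(n_contexts / data_dim) * data_dim)
    n_pad = 0 if n_contexts % b == 0 else b - n_contexts % b
    return max_input_length, b, n_pad


# sketch_plan_batches(6, 1000, 1024, 4, 2_000_000) -> (2048, 8, 2)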
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
ht = LLMNeedleHaystackTester(
haystack_file=FLAGS.haystack_file,
context_lengths_min=FLAGS.context_lengths_min,
context_lengths_max=FLAGS.context_lengths_max,
context_lengths_num_intervals=FLAGS.n_context_length_intervals,
document_depth_percent_intervals=FLAGS.n_document_depth_intervals,
)
ht.start_test()
if __name__ == "__main__":
run(main)
================================================
FILE: scripts/eval_needle_multi.py
================================================
from absl.app import run
import glob
import time
import json
import math
import os
from tqdm import tqdm
import random
from functools import cached_property
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import gcsfs
import tiktoken
from transformers import GenerationConfig, AutoTokenizer
from tux import (
define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,
set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,
match_partition_rules, make_shard_and_gather_fns,
with_sharding_constraint, tree_apply, open_file
)
from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM
FLAGS, FLAGS_DEF = define_flags_with_default(
haystack_file="",
max_tokens_per_batch=2000000,
output_file="results.json",
context_lengths_min=1000,
context_lengths_max=32000,
n_context_length_intervals=3,
n_document_depth_intervals=3,
n_rounds=2,
n_needles_total=4,
n_needles_retrieve=4,
seed=1234,
mesh_dim='1,-1,1,1',
dtype='fp32',
load_llama_config='',
update_llama_config='',
load_checkpoint='',
tokenizer='LargeWorldModel/LWM-Text-1M',
checkpointer=StreamingCheckpointer.get_default_config(),
llama=LLaMAConfig.get_default_config(),
jax_distributed=JaxDistributedConfig.get_default_config(),
)
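# Illustrative sketch (hypothetical helper, not called by this script) of how
# run_test() draws one multi-needle trial: n_total distinct cities at n_total
# distinct depths, each with a random n_digits number, of which n_retrieve are
# asked about. Assumption: cities are sampled from the deduplicated city list,
# because RANDOM_NEEDLE_CITIES repeats entries and a repeated city would
# silently merge two needles into one dict key.
import random


def sketch_sample_needles(cities, depths, n_total, n_retrieve, n_digits, rng):
    unique_cities = sorted(set(cities))  # sorted so a seeded rng is reproducible
    chosen = rng.sample(unique_cities, n_total)
    chosen_depths = rng.sample(list(depths), n_total)
    needles_info = {
        city: (str(rng.randint(10 ** (n_digits - 1), 10 ** n_digits - 1)), depth)
        for city, depth in zip(chosen, chosen_depths)
    }
    retrieve = rng.sample(chosen, n_retrieve)
    return needles_info, retrieve


# info, ask = sketch_sample_needles(["Paris", "Paris", "Oslo", "Lima"],
#                                   [0, 50, 100], 3, 2, 7, random.Random(0))
# Here len(info) == 3 and every city in `ask` is a key of `info`.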
class LLMNeedleHaystackTester:
OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document. ASSISTANT: "
RANDOM_NEEDLE_CITIES = [
'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'
]
def __init__(self,
needle="",
haystack_file="",
retrieval_question="What are the special magic numbers for {}?",
results_version = 1,
rnd_number_digits = 7,
context_lengths_min = 1000,
context_lengths_max = 126000,
context_lengths_num_intervals = 10,
document_depth_percent_min = 0,
document_depth_percent_max = 100,
document_depth_percent_intervals = 10,
document_depth_percent_interval_type = "linear",
save_results = False,
final_context_length_buffer = 200,
print_ongoing_status = True):
needle = needle or "\nThe special magic {city} number is: {rnd_number}\n"
self.needle = needle
if not needle or not haystack_file or not retrieval_question:
raise ValueError("Needle, haystack, and retrieval_question must be provided.")
self.rnd_number_digits = rnd_number_digits
self.context_lengths_num_intervals = context_lengths_num_intervals
self.document_depth_percent_intervals = document_depth_percent_intervals
self.haystack_file = haystack_file
self.retrieval_question = retrieval_question
self.results_version = results_version
self.save_results = save_results
self.final_context_length_buffer = final_context_length_buffer
self.print_ongoing_status = print_ongoing_status
self.testing_results = []
self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
self.context_lengths = self.context_lengths.tolist()
if document_depth_percent_interval_type == 'linear':
self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
elif document_depth_percent_interval_type == 'sigmoid':
self.document_depth_percents = np.array([self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)])
else:
raise ValueError(f"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}")
self.document_depth_percents = self.document_depth_percents.tolist()
self.model = Sampler()
self.enc = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
self.enc_tiktoken = tiktoken.encoding_for_model("gpt-4-1106-preview")
def generate_random_number(self, num_digits):
lower_bound = 10**(num_digits - 1)
upper_bound = 10**num_digits - 1
return random.randint(lower_bound, upper_bound)
def logistic(self, x, L=100, x0=50, k=.1):
if x == 0:
return 0
if x == 100:
return 100
return np.round(L / (1 + np.exp(-k * (x - x0))), 3)
def read_context_files(self, n):
max_context_length = max(self.context_lengths)
contexts = []
f = open_file(self.haystack_file, 'r')
for i in range(n):
context = ""
while len(self.enc.encode(context)) < max_context_length:
context += json.loads(f.readline())['text']
contexts.append(context)
return contexts
def encode_and_trim(self, context, context_length):
tokens = self.enc.encode(context)
if len(tokens) > context_length:
context = self.enc.decode(tokens[:context_length])
return context
def create_contexts(self, needles_info, random_cities_retrieve, context, context_length, seed):
assert all([random_city in needles_info for random_city in random_cities_retrieve])
for random_city, (needle_rnd_number, depth_percent) in needles_info.items():
context = self.generate_context(
self.needle.format(city=random_city, rnd_number=needle_rnd_number),
context, context_length, depth_percent
)
if len(random_cities_retrieve) == 1:
question = f"What is the special magic number for {random_cities_retrieve[0]}?"
else:
q = ', '.join(random_cities_retrieve[:-1]) + ', and ' + random_cities_retrieve[-1]
question = self.retrieval_question.format(q)
results = {
'context' : context,
'context_length' : int(context_length),
'needles_info': needles_info,
'question' : question,
'cities_to_retrieve' : random_cities_retrieve,
'seed': seed,
}
return results
def insert_needle(self, needle, context, depth_percent, context_length):
tokens_needle = self.enc_tiktoken.encode(needle)
tokens_context = self.enc_tiktoken.encode(context)
# Reduce the context length by final_context_length_buffer (200 tokens by default) to leave room for the system message, the user question, and the response.
context_length -= self.final_context_length_buffer
# If the context plus the needle exceed the context length (the usual case), trim the context by the needle length.
if len(tokens_context) + len(tokens_needle) > context_length:
tokens_context = tokens_context[:context_length - len(tokens_needle)]
if depth_percent == 100:
# If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end
tokens_new_context = tokens_context + tokens_needle
else:
# Go get the position (in terms of tokens) to insert your needle
insertion_point = int(len(tokens_context) * (depth_percent / 100))
# tokens_new_context represents the tokens before the needle
tokens_new_context = tokens_context[:insertion_point]
# We want to make sure that we place our needle at a sentence break so we first see what token a '.' is
period_tokens = self.enc_tiktoken.encode('.')
# Then iterate backwards until the prefix ends with a period
while tokens_new_context and tokens_new_context[-1] not in period_tokens:
insertion_point -= 1
tokens_new_context = tokens_context[:insertion_point]
# Once we get there, then add in your needle, and stick the rest of your context in on the other end.
# Now we have a needle in a haystack
tokens_new_context += tokens_needle + tokens_context[insertion_point:]
# Convert back to a string and return it
new_context = self.enc_tiktoken.decode(tokens_new_context)
return new_context
def generate_context(self, needle, trim_context, context_length, depth_percent):
context = self.insert_needle(needle, trim_context, depth_percent, context_length)
return context
def compute_max_input_length(self, context_length, buffer=1024):
block_size = self.model.block_size
context_length += buffer
# context_length = 2 ** math.ceil(math.log2(context_length))
context_length = math.ceil(context_length / block_size) * block_size
return int(context_length)
def run_test(self):
fs = gcsfs.GCSFileSystem()
contexts = []
template = self.OURS_TEMPLATE
def _key_from_result(result):
# Multi-needle results carry no 'depth_percent'; key on (context_length, seed) to match the resume check in the loop.
return (result['context_length'], result['seed'])
results = []
completed = set()
def exists(fname):
if fname.startswith('gs://'):
return fs.exists(fname)
else:
return os.path.exists(fname)
if exists(FLAGS.output_file):
with open_file(FLAGS.output_file, 'r') as f:
results = json.load(f)
completed = set([_key_from_result(result) for result in results])
print('completed', len(completed))
full_contexts = self.read_context_files(FLAGS.n_rounds)
full_tokens = [self.enc.encode(full_context) for full_context in full_contexts]
start = time.time()
for context_length in self.context_lengths:
trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in full_tokens]
max_input_length = self.compute_max_input_length(context_length)
contexts = []
for i in range(FLAGS.n_rounds):
if (int(context_length), i) in completed:
continue
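# Sample from the deduplicated list: RANDOM_NEEDLE_CITIES repeats cities, and a repeat would merge two needles in needles_info.
random_cities = random.sample(sorted(set(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES)), FLAGS.n_needles_total)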
document_depths = random.sample(self.document_depth_percents, FLAGS.n_needles_total)
random_cities_retrieve = random.sample(random_cities, FLAGS.n_needles_retrieve)
needles_info = {}
for random_city, depth_percent in zip(random_cities, document_depths):
needles_info[random_city] = (
str(self.generate_random_number(self.rnd_number_digits)),
depth_percent
)
context = self.create_contexts(needles_info, random_cities_retrieve, trim_contexts[i], context_length, i)
contexts.append(context)
if len(contexts) == 0:
continue
B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
B = int(B / self.model.data_dim) * self.model.data_dim
if B < self.model.data_dim:
B = self.model.data_dim
elif B > len(contexts):
B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
n_pad = (B - len(contexts) % B) % B  # 0 when len(contexts) is already a multiple of B
for _ in range(n_pad):
contexts.insert(0, contexts[0])
pbar = tqdm(total=len(contexts))
for i in range(0, len(contexts), B):
contexts_i = contexts[i:i + B]
prompts = [
template.format(context=context['context'], question=context['question'])
for context in contexts_i
]
outs = self.model(prompts, max_input_length)
for j, (context, out) in enumerate(zip(contexts_i, outs)):
if i + j < n_pad:
continue
rnd_nums_to_retrieve = [
context['needles_info'][city][0] for city in context['cities_to_retrieve']
]
results.append({
'context_length': context['context_length'],
'needles_info': context['needles_info'],
'question': context['question'],
'answer': rnd_nums_to_retrieve,
'response': out,
'correct': [rnd_num in out for rnd_num in rnd_nums_to_retrieve],
'seed': context['seed'],
})
print(results[-1]['correct'], out, rnd_nums_to_retrieve)
if jax.process_index() == 0:
with open_file(FLAGS.output_file, 'w') as f:
json.dump(results, f)
pbar.update(len(contexts_i))
pbar.close()
print('elapsed', time.time() - start)
print('done')
def print_start_test_summary(self):
print ("\n")
print ("Starting Needle In A Haystack Testing...")
print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
print (f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
print (f"- Needle: {self.needle.strip()}")
print ("\n\n")
def start_test(self):
if self.print_ongoing_status:
self.print_start_test_summary()
self.run_test()
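# Illustrative sketch (hypothetical helpers, not called by this script) of the
# multi-needle question wording in create_contexts() and the per-needle scoring
# in run_test(): a needle counts as retrieved when its number appears anywhere
# in the response as a substring.
def sketch_multi_question(cities):
    if len(cities) == 1:
        return f"What is the special magic number for {cities[0]}?"
    joined = ', '.join(cities[:-1]) + ', and ' + cities[-1]
    return f"What are the special magic numbers for {joined}?"


def sketch_score(response, numbers):
    return [number in response for number in numbers]


# sketch_multi_question(["Oslo", "Lima"])
# -> "What are the special magic numbers for Oslo, and Lima?"
# sketch_score("Oslo: 1234567", ["1234567", "7654321"]) -> [True, False]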
class Sampler:
def __init__(self):
self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')
self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
self.sharded_rng = next_rng()
self._load_model()
@property
def block_size(self):
# return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size)
return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']
@property
def data_dim(self):
return self.mesh.shape['dp'] * self.mesh.shape['fsdp']
def _load_model(self):
if FLAGS.load_llama_config != '':
llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
updates = LLaMAConfig(**FLAGS.llama)
llama_config.update(dict(
scan_attention=updates.scan_attention,
scan_mlp=updates.scan_mlp,
scan_query_chunk_size=updates.scan_query_chunk_size,
scan_key_chunk_size=updates.scan_key_chunk_size,
scan_mlp_chunk_size=updates.scan_mlp_chunk_size,
scan_layers=updates.scan_layers,
param_scan_axis=updates.param_scan_axis,
))
else:
llama_config = LLaMAConfig(**FLAGS.llama)
if FLAGS.update_llama_config != '':
llama_config.update(dict(eval(FLAGS.update_llama_config)))
llama_config.update(dict(
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
))
llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))
self.config = llama_config
with jax.default_device(jax.devices("cpu")[0]):
_, self.params = StreamingCheckpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30
)
self.model = FlaxLLaMAForCausalLM(
llama_config,
input_shape=(512, self.block_size),
seed=FLAGS.seed,
_do_init=False,
dtype=get_float_dtype_by_name(FLAGS.dtype),
)
self.model_ps = match_partition_rules(
LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params
)
shard_fns, _ = make_shard_and_gather_fns(
self.model_ps, get_float_dtype_by_name(FLAGS.dtype)
)
with self.mesh:
self.params = tree_apply(shard_fns, self.params)
@cached_property
def _forward_generate(self):
def fn(params, rng, batch):
batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))
rng_generator = JaxRNG(rng)
output = self.model.generate(
batch['input_ids'],
attention_mask=batch['attention_mask'],
params=params['params'],
prng_key=rng_generator(),
generation_config=GenerationConfig(
max_new_tokens=self.block_size,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
temperature=0.,
do_sample=False,
num_beams=1,
top_k=50,
top_p=1.0,
)
).sequences[:, batch['input_ids'].shape[1]:]
return output, rng_generator()
return pjit(
fn,
in_shardings=(self.model_ps, PS(), PS()),
out_shardings=(PS(), PS())
)
def __call__(self, prompts, max_input_length):
inputs = self.prefix_tokenizer(
prompts,
padding='max_length',
truncation=True,
max_length=max_input_length,
return_tensors='np'
)
batch = dict(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask
)
with self.mesh:
output, self.sharded_rng = self._forward_generate(
self.params, self.sharded_rng, batch
)
output = jax.device_get(output)
output_text = []
for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):
if self.tokenizer.eos_token in text:
text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]
output_text.append(text)
return output_text
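# Illustrative sketch (hypothetical helper, not called by this script) of what
# Sampler's prefix_tokenizer does with padding_side='left',
# truncation_side='left' and padding='max_length'. Long prompts lose tokens from
# the front, so the question at the end survives; short prompts are padded on
# the left, so every row ends at the same position where generation begins.
# Special-token handling (e.g. BOS) is ignored here for simplicity.
def sketch_left_pad_truncate(ids, max_length, pad_id):
    ids = list(ids)
    if len(ids) > max_length:
        ids = ids[len(ids) - max_length:]  # keep the last max_length tokens
    n_pad = max_length - len(ids)
    input_ids = [pad_id] * n_pad + ids
    attention_mask = [0] * n_pad + [1] * len(ids)
    return input_ids, attention_mask


# sketch_left_pad_truncate([1, 2, 3], 5, 0) -> ([0, 0, 1, 2, 3], [0, 0, 1, 1, 1])
# sketch_left_pad_truncate([1, 2, 3, 4, 5, 6, 7], 5, 0) -> ([3, 4, 5, 6, 7], [1, 1, 1, 1, 1])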
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
set_random_seed(FLAGS.seed)
ht = LLMNeedleHaystackTester(
haystack_file=FLAGS.haystack_file,
context_lengths_min=FLAGS.context_lengths_min,
context_lengths_max=FLAGS.context_lengths_max,
context_lengths_num_intervals=FLAGS.n_context_length_intervals,
document_depth_percent_intervals=FLAGS.n_document_depth_intervals,
)
ht.start_test()
if __name__ == "__main__":
run(main)
================================================
FILE: scripts/run_eval_needle.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export lwm_text_checkpoint=""
# jsonl file containing text for haystack. Each line should be a json
# with a single key "text" containing the text.
export haystack_file=""
export output_file=""
python3 -u scripts/eval_needle.py \
--mesh_dim='!1,-1,4,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(theta=10000000,max_sequence_length=131072,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
--load_checkpoint="params::$lwm_text_checkpoint" \
--tokenizer="$llama_tokenizer_path" \
--max_tokens_per_batch=5000 \
--output_file="$output_file" \
--haystack_file="$haystack_file" \
--context_lengths_min=1000 \
--context_lengths_max=10000 \
--n_context_length_intervals=20 \
--n_document_depth_intervals=20 \
--n_rounds=3
read
================================================
FILE: scripts/run_eval_needle_multi.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export lwm_text_checkpoint=""
# jsonl file containing text for haystack. Each line should be a json
# with a single key "text" containing the text.
export haystack_file=""
export output_file=""
python3 -u scripts/eval_needle_multi.py \
--mesh_dim='!1,1,-1,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(theta=10000000,max_sequence_length=131072,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
--load_checkpoint="params::$lwm_text_checkpoint" \
--tokenizer="$llama_tokenizer_path" \
--max_tokens_per_batch=5000 \
--output_file="$output_file" \
--haystack_file="$haystack_file" \
--context_lengths_min=1000 \
--context_lengths_max=10000 \
--n_context_length_intervals=10 \
--n_document_depth_intervals=10 \
--n_needles_total=4 \
--n_needles_retrieve=2 \
--n_rounds=10
read
================================================
FILE: scripts/run_sample_image.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export vqgan_checkpoint=""
export lwm_checkpoint=""
# Relevant params
# --temperature_*: Sampling temperature applied to the logits for each modality
# --top_k_*: Only sample from the k tokens with the highest logits
# --cfg_scale_*: Classifier-free guidance scale for each modality
# --n_frames: Number of frames to generate. For images, specify 1.
python3 -u -m lwm.vision_generation \
--prompt='Fireworks over the city' \
--output_file='fireworks.png' \
--temperature_image=1.0 \
--top_k_image=8192 \
--cfg_scale_image=5.0 \
--vqgan_checkpoint="$vqgan_checkpoint" \
--n_frames=1 \
--mesh_dim='!1,1,-1,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer="$llama_tokenizer_path"
read
================================================
FILE: scripts/run_sample_video.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export vqgan_checkpoint=""
export lwm_checkpoint=""
# Relevant params
# --temperature_*: Sampling temperature applied to the logits for each modality
# --top_k_*: Only sample from the k tokens with the highest logits
# --cfg_scale_*: Classifier-free guidance scale for each modality
# --n_frames: Number of frames to generate
python3 -u -m lwm.vision_generation \
--prompt='Fireworks over the city' \
--output_file='fireworks.mp4' \
--temperature_image=1.0 \
--temperature_video=1.0 \
--top_k_image=8192 \
--top_k_video=1000 \
--cfg_scale_image=5.0 \
--cfg_scale_video=1.0 \
--vqgan_checkpoint="$vqgan_checkpoint" \
--n_frames=8 \
--mesh_dim='!1,1,-1,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer="$llama_tokenizer_path"
read
================================================
FILE: scripts/run_train_text.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export dataset_path=""
export output_dir=""
export project_id='lwm'
export experiment_note=''
export experiment_id='example-text-train'
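# mesh_dim: dp, fsdp, tp, sp (data, fully sharded data, tensor, and sequence parallelism); -1 means the axis takes all remaining devices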
python3 -u -m lwm.train \
--modality='text' \
--mesh_dim='!1,-1,2,2' \
--dtype='fp32' \
--total_steps=200 \
--log_freq=1 \
--save_model_freq=0 \
--save_milestone_freq=10 \
--load_llama_config='debug' \
--update_llama_config="dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \
--tokenizer="$llama_tokenizer_path" \
--optimizer.type='adamw' \
--optimizer.accumulate_gradient_steps=1 \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=8e-5 \
--optimizer.adamw_optimizer.end_lr=8e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=5 \
--optimizer.adamw_optimizer.lr_decay_steps=200 \
--use_data_sharded_loader=True \
--train_dataset.type='json' \
--train_dataset.text_processor.fields='text' \
--train_dataset.json_dataset.path="$dataset_path" \
--train_dataset.json_dataset.seq_length=2048 \
--train_dataset.json_dataset.batch_size=1024 \
--train_dataset.json_dataset.tokenizer_processes=16 \
--train_dataset.json_dataset.use_data_sharded_loader=True \
--checkpointer.save_optimizer_state=True \
--autoresume=False \
--logger.append_uuid=False \
--logger.online=False \
--logger.project_id="$project_id" \
--logger.experiment_id="$experiment_id" \
--logger.experiment_note="$experiment_note" \
--logger.output_dir="$output_dir" \
--logger.wandb_dir="$HOME/experiment_output/$project_id"
read
================================================
FILE: scripts/run_train_vision_text.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export dataset_path=""
export output_dir=""
export project_id='lwm'
export experiment_note=''
export experiment_id='example-vision-text-train'
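# mesh_dim: dp, fsdp, tp, sp (data, fully sharded data, tensor, and sequence parallelism); -1 means the axis takes all remaining devices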
python3 -u -m lwm.train \
--modality='vision,text' \
--mesh_dim='!1,-1,2,2' \
--dtype='fp32' \
--total_steps=200 \
--log_freq=1 \
--save_model_freq=0 \
--save_milestone_freq=10 \
--load_llama_config='debug' \
--update_llama_config="dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=8192,scan_layers=True)" \
--tokenizer="$llama_tokenizer_path" \
--optimizer.type='adamw' \
--optimizer.accumulate_gradient_steps=1 \
--optimizer.adamw_optimizer.weight_decay=0.1 \
--optimizer.adamw_optimizer.lr=8e-5 \
--optimizer.adamw_optimizer.end_lr=8e-5 \
--optimizer.adamw_optimizer.lr_warmup_steps=5 \
--optimizer.adamw_optimizer.lr_decay_steps=200 \
--use_data_sharded_loader=True \
--train_dataset.type='json_vision' \
--train_dataset.vision_text_processor.fields_from_example='fields' \
--train_dataset.vision_text_processor.max_n_frames=4 \
--train_dataset.json_vision_dataset.mode="no_pad" \
--train_dataset.json_vision_dataset.path="$dataset_path" \
--train_dataset.json_vision_dataset.seq_length=2048 \
--train_dataset.json_vision_dataset.batch_size=8 \
--train_dataset.json_vision_dataset.tokenizer_processes=4 \
--train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=2 \
--train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=8 \
--train_dataset.json_vision_dataset.use_data_sharded_loader=True \
--checkpointer.save_optimizer_state=True \
--autoresume=False \
--logger.append_uuid=False \
--logger.online=False \
--logger.project_id="$project_id" \
--logger.experiment_id="$experiment_id" \
--logger.experiment_note="$experiment_note" \
--logger.output_dir="$output_dir" \
--logger.wandb_dir="$HOME/experiment_output/$project_id"
read
================================================
FILE: scripts/run_vision_chat.sh
================================================
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export vqgan_checkpoint=""
export lwm_checkpoint=""
export input_file=""
# Relevant params
# --input_file: An image file (png or jpg) or a video file (any format supported by decord, e.g. mp4)
# --max_n_frames: Maximum number of frames to process. If the video has more than max_n_frames frames, max_n_frames frames are sampled uniformly from it
python3 -u -m lwm.vision_chat \
--prompt="What is the video about?" \
--input_file="$input_file" \
--vqgan_checkpoint="$vqgan_checkpoint" \
--mesh_dim='!1,1,-1,1' \
--dtype='fp32' \
--load_llama_config='7b' \
--max_n_frames=8 \
--update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=2048,scan_layers=True)" \
--load_checkpoint="params::$lwm_checkpoint" \
--tokenizer="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
read
================================================
FILE: scripts/sample_pyt.py
================================================
import argparse
from transformers import LlamaForCausalLM, LlamaTokenizer
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str, default='LargeWorldModel/LWM-Text-Chat-256K',
                    help='Hugging Face model ID or local path of the checkpoint')
args = parser.parse_args()
# Load the model weights and the matching tokenizer
model = LlamaForCausalLM.from_pretrained(args.model)
tokenizer = LlamaTokenizer.from_pretrained(args.model)
# The template is only relevant for chat models; non-chat models do not need it
TEMPLATE = "You are a helpful assistant. USER: {} ASSISTANT:"
question = "What is the capital of France?"
prompt = TEMPLATE.format(question)
inputs = tokenizer(prompt, return_tensors="pt")
# Pass the attention mask along with input_ids; max_length also counts the prompt tokens
generate_ids = model.generate(**inputs, max_length=300)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
print(output)
================================================
FILE: tpu_requirements.sh
================================================
#! /bin/bash
sudo apt-get update && sudo apt-get install -y \
build-essential \
python-is-python3 \
tmux \
htop \
git \
ffmpeg
# Update pip
pip install --upgrade pip
# Python dependencies
cat > $HOME/tpu_requirements.txt <<- EndOfFile
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]==0.4.29
flax==0.8.4
optax==0.2.2
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.0
transformers==4.40.0
ringattention @ git+https://github.com/haoliuhl/ringattention.git
datasets
einops
tqdm
ml_collections
wandb
gcsfs
requests
typing-extensions
sentencepiece
tux @ git+https://github.com/haoliuhl/tux.git
Pillow
ffmpeg-python
ipdb
imageio[ffmpeg]
opencv-python
decord
h5py
psutil
EndOfFile
pip install --upgrade -r $HOME/tpu_requirements.txt
# vim configurations
cat > $HOME/.vimrc <<- EndOfFile
set tabstop=4
set shiftwidth=4
set softtabstop=4
set expandtab
set backspace=indent,eol,start
syntax on
EndOfFile
# tmux configurations
cat > $HOME/.tmux.conf <<- EndOfFile
bind r source-file ~/.tmux.conf \; display-message "█▓░ ~/.tmux.conf reloaded."
# Enable colors, https://github.com/tmux/tmux/wiki/FAQ
set -g default-terminal "tmux-256color"
# start with window 1 (instead of 0)
set -g base-index 1
setw -g pane-base-index 1
set -g prefix C-a
set -g set-titles on
set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)'
# Status bar customization
set -g status-interval 5
set -g status-left-length 90
set -g status-right-length 60
set -g status-justify left
# send the prefix to client inside window (ala nested sessions)
bind-key a send-prefix
bind-key x kill-pane
# auto reorder
set-option -g renumber-windows on
# show the session name on the left side of the status bar
set -g status-left "#[fg=green,bg=colour236] #S "
# default statusbar colors
set-option -g status-style fg=yellow,dim,bg=colour235
# default window title colors
set-window-option -g window-status-style fg=yellow,bg=colour236,dim
# active window title colors
set-window-option -g window-status-current-style fg=brightred,bg=colour236
# basename as window title https://stackoverflow.com/a/37136828
set-window-option -g window-status-current-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)'
set-window-option -g window-status-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)'
# pane border
set-option -g pane-border-style fg=white #base2
set-option -g pane-active-border-style fg=brightcyan #base1
# enable mouse click
set -g mouse on
# keep panes open after their command exits
set -g remain-on-exit on
# Longer scrollback history
set -g history-limit 50000
# Scroll position indicator
set -g mode-style bg=colour235,fg=colour245
# SSH agent forwarding
# set-environment -g SSH_AUTH_SOCK $SSH_AUTH_SOCK
if-shell '[ -n $SSH_AUTH_SOCK ]' " \
set-option -sg update-environment \"DISPLAY WINDOWID XAUTHORITY\"; \
setenv -g SSH_AUTH_SOCK /tmp/ssh_auth_sock_tmux; \
run-shell \"ln -sf $(find /tmp/ssh-* -type s -readable | head -n 1) /tmp/ssh_auth_sock_tmux\" \
"
# Drag windows on the status bar
bind-key -n MouseDrag1Status swap-window -t=
EndOfFile
# htop Configurations
mkdir -p $HOME/.config/htop
cat > $HOME/.config/htop/htoprc <<- EndOfFile
# Beware! This file is rewritten by htop when settings are changed in the interface.
# The parser is also very primitive, and not human-friendly.
fields=0 48 17 18 38 39 40 2 46 47 49 1
sort_key=46
sort_direction=1
hide_threads=0
hide_kernel_threads=1
hide_userland_threads=1
shadow_other_users=0
show_thread_names=0
show_program_path=1
highlight_base_name=0
highlight_megabytes=1
highlight_threads=1
tree_view=0
header_margin=1
detailed_cpu_time=0
cpu_count_from_zero=0
update_process_names=0
account_guest_in_cpu_meter=0
color_scheme=0
delay=15
left_meters=CPU Memory Swap
left_meter_modes=1 1 1
right_meters=Tasks LoadAverage Uptime
right_meter_modes=2 2 2
EndOfFile