Repository: LargeWorldModel/LWM Branch: main Commit: f45d2b70bda2 Files: 26 Total size: 244.6 KB Directory structure: gitextract_q2zwva4c/ ├── .gitignore ├── LICENSE ├── README.md ├── docs/ │ ├── data.md │ └── sharding.md ├── gpu_requirements.txt ├── lwm/ │ ├── __init__.py │ ├── data.py │ ├── llama.py │ ├── train.py │ ├── vision_chat.py │ ├── vision_generation.py │ ├── vision_llama.py │ └── vqgan.py ├── scripts/ │ ├── create_needle_data.py │ ├── eval_needle.py │ ├── eval_needle_multi.py │ ├── run_eval_needle.sh │ ├── run_eval_needle_multi.sh │ ├── run_sample_image.sh │ ├── run_sample_video.sh │ ├── run_train_text.sh │ ├── run_train_vision_text.sh │ ├── run_vision_chat.sh │ └── sample_pyt.py └── tpu_requirements.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # local jobs/ local/ .vscode/ data/ *.model *.npy *.jsonl *.pkl *.json __pycache__/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. 
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Large World Model (LWM) [[Project]](https://largeworldmodel.github.io/) [[Paper]](https://arxiv.org/abs/2402.08268) [[Models]](https://huggingface.co/LargeWorldModel) **Large World Model (LWM)** is a general-purpose large-context multimodal autoregressive model. It is trained on a large dataset of diverse long videos and books using RingAttention, and can perform language, image, and video understanding and generation. ## Approach
Current language models fall short in understanding aspects of the world not easily described in words, and struggle with complex, long-form tasks. Video sequences offer valuable temporal information absent in language and static images, making them attractive for joint modeling with language. Such models could develop an understanding of both human textual knowledge and the physical world, enabling broader AI capabilities for assisting humans. However, learning from millions of tokens of video and language sequences poses challenges due to memory constraints, computational complexity, and limited datasets. To address these challenges, we curate a large dataset of diverse videos and books, utilize the RingAttention technique to scalably train on long sequences, and gradually increase context size from 4K to 1M tokens. This paper makes the following contributions: (a) Largest context size neural network: We train one of the largest context size transformers on long video and language sequences, setting new benchmarks in difficult retrieval tasks and long video understanding. (b) Solutions for overcoming vision-language training challenges, including using masked sequence packing for mixing different sequence lengths, loss weighting to balance language and vision, and a model-generated QA dataset for long sequence chat. (c) A highly-optimized implementation with RingAttention, masked sequence packing, and other key features for training on millions-length multimodal sequences. (d) A fully open-sourced family of 7B parameter models capable of processing long text documents (LWM-Text, LWM-Text-Chat) and videos (LWM, LWM-Chat) of over 1M tokens. This work paves the way for training on massive datasets of long video and language to develop understanding of both human knowledge and the multimodal world, and broader capabilities.

## LWM Capabilities

LWM can retrieve facts across a 1M-token context with high accuracy.


LWM can answer questions about a 1-hour YouTube video.


LWM can chat with images.


LWM can generate videos and images from text.

## Setup

This codebase is supported on Ubuntu and has not been tested on Windows or macOS. We recommend using TPUs for training and inference, although it is also possible to use GPUs. On TPU, the code is highly optimized with Jax's Pallas and can achieve high MFUs with RingAttention at very large context sizes. On GPU, the code is based on XLA and is not as optimized as it is for TPU.

Install the requirements with:
```
conda create -n lwm python=3.10
conda activate lwm
pip install -r gpu_requirements.txt
```
or set up a TPU VM with:
```
sh tpu_requirements.sh
```

## Available models

There are language-only and vision-language versions, offering context sizes from 32K to 1M tokens. The vision-language models are available only in Jax, and the language-only models are available in both PyTorch and Jax. Below are the names of the available models and their corresponding context sizes and capabilities:

| Model Name | Context Size | Language or Vision-Language | Chat or Base | URL |
|--------------------|--------------|-----------------------------|--------------|-----|
| LWM-Text-Chat-128K | 128K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K-Jax)] |
| LWM-Text-Chat-256K | 256K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K-Jax)] |
| LWM-Text-Chat-512K | 512K | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K-Jax)] |
| LWM-Text-Chat-1M | 1M | Language | Chat | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M-Jax)] |
| 
LWM-Text-128K | 128K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-128K-Jax)] |
| LWM-Text-256K | 256K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-256K-Jax)] |
| LWM-Text-512K | 512K | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-512K-Jax)] |
| LWM-Text-1M | 1M | Language | Base | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-1M-Jax)] |
| LWM-Chat-32K | 32K | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-32K-Jax)] |
| LWM-Chat-128K | 128K | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-128K-Jax)] |
| LWM-Chat-1M | 1M | Vision-Language | Chat | [[Jax](https://huggingface.co/LargeWorldModel/LWM-1M-Jax)] |

## Code structure

Use `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network.

Set `scan_attention` and `scan_mlp` to `True` or `False` to enable or disable blockwise compute in the self-attention and the feed-forward network, respectively.

You can use `mesh_dim=dp, fsdp, tp, sp` to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the degrees of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism. For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.
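
As a rough illustration of the `mesh_dim` format, the sketch below splits the string into its four named axes and prints a description in the same wording used above. The helper name `describe_mesh_dim` is hypothetical and is not part of this codebase; it only covers the plain form of four comma-separated positive integers.

```python
# Axis order as documented: dp, fsdp, tp, sp.
AXIS_NAMES = ("dp", "fsdp", "tp", "sp")
AXIS_LABELS = (
    "data parallelism",
    "fully sharded data parallelism",
    "tensor parallelism",
    "sequence parallelism",
)


def describe_mesh_dim(mesh_dim: str) -> dict:
    """Hypothetical helper: map a 'dp,fsdp,tp,sp' string to named axis sizes."""
    sizes = [int(part) for part in mesh_dim.split(",")]
    if len(sizes) != len(AXIS_NAMES):
        raise ValueError(f"expected 4 comma-separated integers, got {mesh_dim!r}")
    return dict(zip(AXIS_NAMES, sizes))


mesh = describe_mesh_dim("1,1,4,64")
# Print one line per axis, e.g. "sp = 64 (sequence parallelism)".
for name, label in zip(AXIS_NAMES, AXIS_LABELS):
    print(f"{name} = {mesh[name]} ({label})")
```

Keeping the axes in a dictionary makes it harder to confuse the order when reading a launch script.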
## Running Jax Models

In this section, we provide instructions on how to run each of the provided scripts. For each script, you may need to fill in your own paths and values in the variables described at the beginning of each script.

To run each of the following scripts, use `bash scripts/<script_name>.sh`:
- Language model training: `bash scripts/run_train_text.sh`
- Vision-Language model training: `bash scripts/run_train_vision_text.sh`
- Single Needle Evals (Language Model): `bash scripts/run_eval_needle.sh`
- Multi Needle Evals (Language Model): `bash scripts/run_eval_needle_multi.sh`
- Sampling images (Vision-Language Model): `bash scripts/run_sample_image.sh`
- Sampling videos (Vision-Language Model): `bash scripts/run_sample_video.sh`
- Image / Video understanding (Vision-Language Model): `bash scripts/run_vision_chat.sh`

By default the `mesh_dim` argument puts all devices on `tp` (tensor parallelism). For longer sequences, you may want to include `sp`, which is the last dimension in the `mesh_dim`.

When running needle evals, you may need to adjust the `theta` and `max_sequence_length` arguments in the scripts depending on the model. The table below shows the correct values for each model.

| | LWM-Text-128K / LWM-Text-Chat-128K | LWM-Text-256K / LWM-Text-Chat-256K | LWM-Text-512K / LWM-Text-Chat-512K | LWM-Text-1M / LWM-Text-Chat-1M |
|---------------------|:-----------------------------------:|:-----------------------------------:|:----------------------------------:|:------------------------------:|
| theta | 10000000 | 10000000 | 25000000 | 50000000 |
| max_sequence_length | 131072 | 262144 | 524288 | 1048576 |

An example of filling out a script (`run_sample_video.sh`) is as follows:
```bash
#!
/bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export vqgan_checkpoint="/path/to/ckpt/folder/vqgan" export lwm_checkpoint="params::/path/to/ckpt/folder/params" python3 -u -m lwm.vision_generation \ --prompt='Fireworks over the city' \ --output_file='fireworks.mp4' \ --temperature_image=1.0 \ --temperature_video=1.0 \ --top_k_image=8192 \ --top_k_video=1000 \ --cfg_scale_image=5.0 \ --cfg_scale_video=1.0 \ --vqgan_checkpoint="$vqgan_checkpoint" \ --n_frames=8 \ --mesh_dim='!1,1,-1,1' \ --dtype='fp32' \ --load_llama_config='7b' \ --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \ --load_checkpoint="$lwm_checkpoint" \ --tokenizer="$llama_tokenizer_path" read ``` ## Needle Haystack Data Run `python scripts/create_needle_data.py` ## Running PyTorch Models Only text and text chat models are currently supported for PyTorch inference. PyTorch models can be loaded as Hugging Face `LlamaForCausalLM` models. Run `python scripts/sample_pyt.py` to sample. You may need to separately install `torch`. ## Documentation For more details on the codebase, please refer to the [data.md](docs/data.md) and [sharding.md](docs/sharding.md). The [data.md](docs/data.md) provides details on the data processing and the [sharding.md](docs/sharding.md) provides details on the sharding and parallelism. ## If you have issues This is based on the [codebase](https://github.com/haoliuhl/ringattention) of RingAttention, with the necessary features for vision-language training. The training and inference have been tested on both TPUv3 and TPUv4. 
If you encounter bugs, please open a GitHub issue!

## Citation

If you use this codebase, or otherwise find our work valuable, please cite:

```
@article{liu2023world,
    title={World Model on Million-Length Video and Language with RingAttention},
    author={Liu, Hao and Yan, Wilson and Zaharia, Matei and Abbeel, Pieter},
    journal={arXiv preprint},
    year={2024},
}
@article{liu2023ring,
    title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
    journal={International Conference on Learning Representations},
    year={2024}
}
@article{liu2023blockwise,
    title={Blockwise Parallel Transformer for Large Context Models},
    author={Liu, Hao and Abbeel, Pieter},
    journal={Advances in neural information processing systems},
    year={2023}
}
```

## License

LWM's code is released under the Apache 2.0 License. See [LICENSE](https://github.com/LargeWorldModel/lwm/blob/main/LICENSE) for further details. The models are released under the Llama-2 license.

================================================
FILE: docs/data.md
================================================
# Data

We support two types of datasets: Huggingface dataset and JSON dataset. The dataset modules are implemented in the [data.py](/lwm/data.py) file. Configuration requires dataset type, text processor, and dataset specific configurations. The following is an example of using the Huggingface dataset to train a model:

```bash
python -m lwm.train \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.type='huggingface' \
    --train_dataset.huggingface_dataset.path='openwebtext'
```

In this example, we select the Huggingface dataset by specifying the `type` of `train_dataset` to be `huggingface`. We then specify the path to the dataset, which is `openwebtext` in this case. The examples loaded from the dataset will be processed by a TextProcessor, which is configured by the `text_processor` field.

The following options are supported for the dataset module:
* `type`: The type of the dataset. Supported values are `huggingface` and `json`.
* `text_processor`: The configuration of the TextProcessor used to process the loaded examples.
* `huggingface_dataset`: The configuration of the Huggingface dataset.
* `json_dataset`: The configuration of the JSON dataset.

For the Huggingface dataset, examples are loaded from a Huggingface dataset. Here are the configurable options:
* `path`: The path to the dataset. Same as the `path` argument in `datasets.load_dataset`.
* `name`: Name of the dataset within the path. Same as the `name` argument in `datasets.load_dataset`.
* `split`: The split of the dataset. Same as the `split` argument in `datasets.load_dataset`.
* `streaming`: Whether to stream the dataset. Same as the `streaming` argument in `datasets.load_dataset`.
* `seq_length`: The length of the tokenized sequence.
* `batch_size`: Batch size of tokenized examples.

For the JSON dataset, examples are loaded from a text file, where each line represents a JSON encoded dictionary. Here are the configurable options for the JSON dataset:
* `path`: Path to the text file. The file can be located on the local file system or in a Google Cloud Storage bucket.
* `seq_length`: The length of the tokenized sequence.
* `batch_size`: Batch size of tokenized examples.
* `start_seek_loc`: The starting seek location in the file. This is useful when you want to resume training from a particular location in the file.
* `index_at_start`: The counting index at the beginning. This is useful to keep the index count when resuming from a particular location in the file. Note that this is only for logging purposes and does not affect which examples are loaded. To start from a different example in the dataset, use the `start_seek_loc` option.
* `tokenizer_processes`: The number of processes to use for tokenization. Tokenization is done in parallel to speed up the loading process.
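
To make the roles of `start_seek_loc` and `index_at_start` more concrete, here is a minimal sketch of resuming a line-by-line read of a JSON lines file. It is an illustration of the idea only and is not the codebase's `JsonDataset`; the function `read_jsonl` and the demo file name are hypothetical.

```python
import json
import os
import tempfile


def read_jsonl(path, start_seek_loc=0, index_at_start=0):
    """Hypothetical reader: yield (example, next_seek_loc, index) per line."""
    index = index_at_start  # only a counter for logging; it does not move the file
    with open(path, "rb") as f:  # binary mode so tell() returns byte offsets
        f.seek(start_seek_loc)   # this is what actually selects the first example
        while True:
            line = f.readline()
            if not line:
                break
            index += 1
            yield json.loads(line), f.tell(), index


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.jsonl")
    with open(path, "w") as f:
        for text in ["a", "b", "c"]:
            f.write(json.dumps({"text": text}) + "\n")

    records = list(read_jsonl(path))
    _, loc_after_first, index_after_first = records[0]

    # Resume right after the first example, keeping the running index.
    resumed = list(
        read_jsonl(path, start_seek_loc=loc_after_first, index_at_start=index_after_first)
    )
    resumed_texts = [example["text"] for example, _, _ in resumed]
    resumed_indices = [index for _, _, index in resumed]
    print(resumed_texts, resumed_indices)
```

In this sketch, changing `index_at_start` alone would relabel the examples without skipping any, which mirrors the note above that the index is for logging only.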

A JSON dataset can be generated as follows:

```python
from datasets import load_dataset
import json
from multiprocessing import Pool, cpu_count

dataset = load_dataset("openwebtext")

# Hold out a small validation split and rename it from 'test' to 'val'.
split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
split_dataset['val'] = split_dataset.pop('test')

def save_split(split):
    # Write one JSON encoded dictionary per line.
    with open(f"openwebtext_{split}.jsonl", "w") as f:
        for example in split_dataset[split]:
            json.dump({"text": example["text"]}, f)
            f.write("\n")

with Pool(cpu_count()) as p:
    p.map(save_split, ["train", "val"])
```

This generates two files, `openwebtext_train.jsonl` and `openwebtext_val.jsonl`, which can be used as the dataset for training. Both files contain a single field, `text`, which is the text to be processed by the model. For example, to train a model using the `openwebtext_train.jsonl` file, you can use the following command:

```bash
python -m lwm.train \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.type='json' \
    --train_dataset.json_dataset.path='openwebtext_train.jsonl'
```

For vision-language training, we recommend using the JSON dataset, as it allows you to pre-tokenize vision (images and videos), and load the tokenized vision along with the text.

Each loaded example is a dictionary, which will be processed by a TextProcessor.

## Text Processor

We use the `TextProcessor` class to process the loaded examples from a dataset. This allows us to flexibly process various formats. Each input example is a dictionary of multiple text fields. The TextProcessor will process text fields according to its configurations, and return the final tokens.

Here are the configurable options for TextProcessor:
* `fields`: A comma separated list of text fields to process.
* `fields_from_example`: Whether to use the keys of the input example as the text fields to process. If this option is set, the `fields` argument will be ignored.
* `subfield_separator`: The text separator to use when concatenating subfields of a text.
* `add_eos_token`: Whether to add an EOS token to the end of the text.
* `prepend_text`: The text to prepend to the beginning.

The most important configuration for TextProcessor is the `fields` argument. It is a comma separated list of text fields to process. Each field consists of one or more subfields, which are separated by a `+`. Each subfield represents a key used to extract the text from the input example dictionary. The TextProcessor joins the extracted subfield texts with the `subfield_separator` at the text level and then tokenizes the joined text. Finally, the TextProcessor concatenates the tokenized text fields at the token level, and adds the EOS token if specified.

Other than the keys in the input example, you can also use the following special keys to indicate a special token for a text field:
* `<|bos|>`: Beginning of sentence token.
* `<|eos|>`: End of sentence token.

For each text field, you can encapsulate the subfields with `[]` to specify that the loss should not be computed for this field. Doing so sets the loss masks to 0 for this field. This is useful when you want to use the text field as a prompt for the model.

To give a concrete example, if the input example looks like this:
```python
{
    'vision': 'VQ tokens of a picture of a cat',
    'question': 'what is the color of the cat',
    'answer': 'the color of the cat is yellow',
}
```
To use the `vision` and `question` as the input text, and `answer` as the target, we can specify the following configuration for the `fields` argument:
```
[vision+question],answer
```

The `vision+question` indicates that the `vision` and `question` should be joined together with the `subfield_separator`, which is a space by default. The `[]` indicates that the loss should not be computed for this field. The `answer` field is then concatenated at the token level, where the loss will be computed.
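
The `fields` syntax described above can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual `TextProcessor` implementation: `parse_fields` and `process_example` are hypothetical names, whitespace splitting stands in for a real tokenizer, and the special keys `<|bos|>` and `<|eos|>` as well as `add_eos_token` are not handled.

```python
def parse_fields(fields: str):
    """Hypothetical parser: return a list of (subfield_keys, loss_mask) pairs."""
    parsed = []
    for field in fields.split(","):
        field = field.strip()
        if field.startswith("[") and field.endswith("]"):
            mask = 0.0          # bracketed field: no loss on its tokens
            field = field[1:-1]
        else:
            mask = 1.0
        parsed.append((field.split("+"), mask))
    return parsed


def process_example(example, fields, tokenize=str.split, subfield_separator=" "):
    """Join subfields at the text level, then concatenate fields at the token level."""
    tokens, loss_masks = [], []
    for keys, mask in parse_fields(fields):
        text = subfield_separator.join(example[key] for key in keys)
        field_tokens = tokenize(text)  # toy tokenizer, assumed for illustration
        tokens.extend(field_tokens)
        loss_masks.extend([mask] * len(field_tokens))
    return tokens, loss_masks


example = {
    "vision": "VQ tokens of a picture of a cat",
    "question": "what is the color of the cat",
    "answer": "the color of the cat is yellow",
}
tokens, loss_masks = process_example(example, "[vision+question],answer")
print(len(tokens), sum(loss_masks))
```

With the toy tokenizer, the prompt part (`vision` and `question`) contributes 15 tokens with mask 0 and the `answer` contributes 7 tokens with mask 1.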

================================================
FILE: docs/sharding.md
================================================
# Sharding

Sharding is a technique to partition the computation and the model across multiple accelerators. This codebase supports flexible model and data parallelism for training and serving.

The sharding can be specified using the `mesh_dim` command line argument. The `mesh_dim` is a comma separated list of integers representing the parallelism mesh axis dimensions. One of the axis dimensions can be `-1`, which means that the axis dimension will be inferred based on the total number of accelerators.

The first axis of the mesh is used for data parallelism (`dp`), the second axis is used for fully sharded data parallelism (`fsdp`), the third axis is used for tensor parallelism (`tp`), and the last axis is used for sequence parallelism (`sp`), which is required for ring attention.

For example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism, while `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.

Your total number of accelerators should be equal to the product of the mesh dimensions. For example, `mesh_dim='1,64,4,1'` requires 1 × 64 × 4 × 1 = 256 accelerators, and `mesh_dim='1,1,4,64'` also requires 256 accelerators.

In general, you want to use the largest possible mesh dimension for `fsdp`. For example, `mesh_dim='1,64,1,1'` is preferred over `mesh_dim='8,8,1,1'` because the former has a larger `fsdp` dimension, which allows overlapping of computation and communication, and thus better performance.

The batch size (number of sequences per batch) should be larger than or equal to `fsdp * dp`. If you think the batch size is too large, you can allocate more accelerators to `tp` and `sp` to increase the model size and sequence length.
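
The sizing rules above can be sketched in a few lines. The helper `resolve_mesh_dim` below is hypothetical and is not the codebase's own mesh construction: it only handles the plain comma-separated integer form, infers a single `-1` axis from an assumed accelerator count, and then reports the smallest batch size compatible with `fsdp * dp`.

```python
from math import prod

AXES = ("dp", "fsdp", "tp", "sp")


def resolve_mesh_dim(mesh_dim: str, n_devices: int) -> dict:
    """Hypothetical helper: fill in a single -1 axis and validate the product."""
    dims = [int(part) for part in mesh_dim.split(",")]
    if len(dims) != len(AXES):
        raise ValueError(f"expected 4 axes, got {len(dims)}")
    if dims.count(-1) > 1:
        raise ValueError("only one axis may be -1")
    if -1 in dims:
        known = prod(d for d in dims if d != -1)
        if n_devices % known != 0:
            raise ValueError(f"{n_devices} devices not divisible by {known}")
        dims[dims.index(-1)] = n_devices // known
    if prod(dims) != n_devices:
        raise ValueError(f"mesh needs {prod(dims)} devices, but {n_devices} are available")
    return dict(zip(AXES, dims))


# Assumed setup for illustration: 256 accelerators, fsdp axis inferred.
mesh = resolve_mesh_dim("1,-1,4,1", n_devices=256)
min_batch_size = mesh["dp"] * mesh["fsdp"]  # batch size should be >= fsdp * dp
print(mesh, min_batch_size)
```

Moving accelerators from `fsdp` to `tp` or `sp` in this sketch lowers `min_batch_size` while keeping the product of the mesh dimensions fixed.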
RingAttention requires sequence parallelism, which is controlled by `sp`. `sp=8` means the sequence length is sharded 8 ways, and `sp=1` means no sequence sharding. For models that use standard attention, you can set `sp=1` and use `dp`, `fsdp`, and `tp` to control the parallelism.

================================================
FILE: gpu_requirements.txt
================================================
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12]==0.4.29
flax==0.8.4
optax==0.2.2
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.0
transformers==4.40.0
ringattention @ git+https://github.com/haoliuhl/ringattention.git
datasets
einops
tqdm
ml_collections
wandb
gcsfs
requests
typing-extensions
sentencepiece
tux @ git+https://github.com/haoliuhl/tux.git
Pillow
ffmpeg-python
ipdb
imageio[ffmpeg]
opencv-python
decord
ffmpeg-python
h5py
psutil

================================================
FILE: lwm/__init__.py
================================================

================================================
FILE: lwm/data.py
================================================
import time
import random
from functools import partial
import json
from multiprocessing import Pool

from tux import open_file
from ml_collections import ConfigDict
import numpy as np
import jax
from jax.experimental.multihost_utils import host_local_array_to_global_array
from jax.sharding import PartitionSpec as PS
from datasets import load_dataset


class DatasetFactory(object):
    """ Dataset builder class.
""" @staticmethod def get_default_config(updates=None): config = ConfigDict() config.type = 'huggingface' config.text_processor = TextProcessor.get_default_config() config.huggingface_dataset = HuggingfaceDataset.get_default_config() config.json_dataset = JsonDataset.get_default_config() config.vision_text_processor = VisionTextProcessor.get_default_config() config.json_vision_dataset = JsonVisionDataset.get_default_config() if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config @classmethod def load_dataset(cls, config, tokenizer, **kwargs): config = cls.get_default_config(config) if config.type == 'huggingface': text_processor = TextProcessor(config.text_processor, tokenizer) return HuggingfaceDataset( config.huggingface_dataset, tokenizer, text_processor, **kwargs ) elif config.type == 'json': text_processor = TextProcessor(config.text_processor, tokenizer) return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs) elif config.type == 'json_vision': vision_text_processor = VisionTextProcessor(config.vision_text_processor, tokenizer) return JsonVisionDataset(config.json_vision_dataset, tokenizer, vision_text_processor, **kwargs) else: raise ValueError(f'Unknown dataset type: {config.type}') def __init__(self): raise ValueError('DatasetFactory is a static class and should not be instantiated.') class TextProcessor(object): """ Example processor that converts a dictionary of texts into tokens. 
""" @staticmethod def get_default_config(updates=None): config = ConfigDict() config.fields_from_example = '' config.fields = '' config.subfield_separator = ' ' config.add_bos_token = True config.add_eos_token = True config.prepend_text = '' if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, tokenizer): self.config = self.get_default_config(config) assert self.config.fields != '' or self.config.fields_from_example != '', ( 'Either fields or fields_from_example must be specified.' ) self.tokenizer = tokenizer def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True): if has_aux: example, *aux = example else: aux = tuple() token_buffer = [] loss_mask_buffer = [] if add_bos_token and self.config.add_bos_token: token_buffer.append(self.tokenizer.bos_token_id) loss_mask_buffer.append(0.0) if self.config.fields_from_example != '': fields = example[self.config.fields_from_example].split(',') else: fields = self.config.fields.split(',') for i, field in enumerate(fields): if field.startswith('[') and field.endswith(']'): # No loss for this field. 
field = field[1:-1] mask = 0.0 else: mask = 1.0 if field == '<|bos|>': token_buffer.append(self.tokenizer.bos_token_id) loss_mask_buffer.append(mask) elif field == '<|eos|>': token_buffer.append(self.tokenizer.eos_token_id) loss_mask_buffer.append(mask) else: subfields = field.split('+') text = self.config.subfield_separator.join( [example[subfield] for subfield in subfields] ) if i == 0: text = self.config.prepend_text + text tokens = self.tokenizer.encode(text, add_special_tokens=False) token_buffer.extend(tokens) loss_mask_buffer.extend([mask for _ in range(len(tokens))]) if add_eos_token and self.config.add_eos_token: token_buffer.append(self.tokenizer.eos_token_id) loss_mask_buffer.append(1.0) return token_buffer, loss_mask_buffer, *aux class VisionTextProcessor(object): @staticmethod def get_default_config(updates=None): config = ConfigDict() config.fields_from_example = '' config.subfield_separator = ' ' config.add_bos_token = True config.add_eos_token = True config.prepend_text = '' config.fields_index = -1 config.eof_token = 8192 # denotes end of each frame for video generation config.eov_token = 8193 # denotes end of vision generation config.n_tokens_per_frame = 256 # 16 x 16 VQ codes config.max_n_frames = -1 if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, tokenizer): self.config = self.get_default_config(config) assert self.config.fields_from_example != '', ( 'fields_from_example must be specified.' 
) self.tokenizer = tokenizer self.vision_start = tokenizer.encode('') self.vision_end = tokenizer.encode('') def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True): if has_aux: example, *aux = example else: aux = tuple() rand_state = random.Random(aux[-1]) # makes augmentations deterministic by line number token_buffer = [] loss_mask_buffer = [] vision_mask = [] fields = example[self.config.fields_from_example] if isinstance(fields, (tuple, list)): if self.config.fields_index >= 0: fields = fields[self.config.fields_index] else: # seed based on line number fields = rand_state.choice(fields) fields = fields.split(',') if add_bos_token and self.config.add_bos_token: token_buffer.append(self.tokenizer.bos_token_id) loss_mask_buffer.append(0.0) vision_mask.append(False) for i, field in enumerate(fields): if field.startswith('[') and field.endswith(']'): # No loss for this field. field = field[1:-1] mask = 0.0 else: mask = 1.0 if field == '<|bos|>': token_buffer.append(self.tokenizer.bos_token_id) loss_mask_buffer.append(mask) vision_mask.append(False) elif field == '<|eos|>': token_buffer.append(self.tokenizer.eos_token_id) loss_mask_buffer.append(mask) vision_mask.append(False) elif 'vision' in field: vision_tokens = example[field] n_frames = int(len(vision_tokens) / self.config.n_tokens_per_frame) if self.config.max_n_frames > 0 and n_frames > self.config.max_n_frames: # uniformly select idxs = np.linspace(0, n_frames - 1, self.config.max_n_frames).astype(int) new_vision_tokens = [] for idx in idxs: new_vision_tokens.extend(vision_tokens[idx * self.config.n_tokens_per_frame:(idx + 1) * self.config.n_tokens_per_frame]) vision_tokens = new_vision_tokens n_frames = self.config.max_n_frames assert int(len(vision_tokens) / self.config.n_tokens_per_frame) == n_frames, (int(len(vision_tokens) / self.config.n_tokens_per_frame), n_frames) assert n_frames > 0, len(vision_tokens) tokens = list(self.vision_start) for j in range(n_frames): 
tokens.extend(vision_tokens[j*self.config.n_tokens_per_frame:(j+1)*self.config.n_tokens_per_frame]) if j == n_frames - 1: # last frame tokens.append(self.config.eov_token) else: tokens.append(self.config.eof_token) tokens.extend(self.vision_end) token_buffer.extend(tokens) loss_mask_buffer.extend([mask for _ in range(len(tokens))]) vision_mask.extend([False] * len(self.vision_start)) vision_mask.extend([True] * (self.config.n_tokens_per_frame * n_frames + n_frames)) # include extra eof/eov token at the end of each frame vision_mask.extend([False] * len(self.vision_end)) else: subfields = field.split('+') text = self.config.subfield_separator.join( [example[subfield] for subfield in subfields] ) if i == 0: text = self.config.prepend_text + text tokens = self.tokenizer.encode(text) token_buffer.extend(tokens) loss_mask_buffer.extend([mask for _ in range(len(tokens))]) vision_mask.extend([False] * len(tokens)) if add_eos_token and self.config.add_eos_token: token_buffer.append(self.tokenizer.eos_token_id) loss_mask_buffer.append(1.0) vision_mask.append(False) assert len(token_buffer) == len(loss_mask_buffer) == len(vision_mask), (len(token_buffer), len(loss_mask_buffer), len(vision_mask)) keep = True return token_buffer, loss_mask_buffer, vision_mask, keep, *aux class HuggingfaceDataset(object): """ Huggingface dataset, where the dataset is loaded using the huggingface datasets.load_dataset() function. 
""" @staticmethod def get_default_config(updates=None): config = ConfigDict() config.path = 'c4' config.name = 'en' config.split = 'train' config.streaming = False config.seq_length = 1024 config.batch_size = 8 config.always_start_with_bos = False if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, tokenizer, text_processor): self.config = self.get_default_config(config) name = self.config.name if self.config.name != '' else None split = self.config.split if self.config.split != '' else None self._tokenizer = tokenizer self._text_processor = text_processor self._dataset = load_dataset( self.config.path, name, split=split, streaming=self.config.streaming ) def __iter__(self): chunk_size = self.config.batch_size * self.config.seq_length total_tokens = 0 while True: token_buffer = [] loss_mask_buffer = [] for index, example in enumerate(self._dataset): tokens, loss_masks = self.text_processor(example) token_buffer.extend(tokens) loss_mask_buffer.extend(loss_masks) while len(token_buffer) > chunk_size + 1: total_tokens += chunk_size metrics = { 'dataset_example_index': index, 'dataset_total_tokens': total_tokens, } batch = { 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape( self.config.batch_size, -1 ), 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape( self.config.batch_size, -1 ), 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape( self.config.batch_size, -1 ), } if self.config.always_start_with_bos: batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id yield batch, metrics token_buffer = token_buffer[chunk_size:] loss_mask_buffer = loss_mask_buffer[chunk_size:] def get_state_dict(self): return dict(config=self.config) def load_state_dict(self, state_dict): if 'config' in state_dict: self.config.update(ConfigDict(state_dict['config'])) @property def seq_length(self): return 
self.config.seq_length @property def tokenizer(self): return self._tokenizer @property def text_processor(self): return self._text_processor @property def dataset(self): return self._dataset @property def vocab_size(self): return len(self._tokenizer) class JsonDataset(object): """ JSON dataset, where each line of the data file contains a JSON dictionary with text fields. """ @staticmethod def get_default_config(updates=None): config = ConfigDict() config.path = '' config.seq_length = 1024 config.batch_size = 8 config.always_start_with_bos = False config.start_seek_loc = 0 config.example_index_at_start = 0 config.tokens_count_at_start = 0 config.tokenizer_processes = 1 config.tokenizer_parallel_chunk_size = 32 config.tokenizer_parallel_batch_size = 1024 config.throughput_average_window_size = 200 config.pad = False config.use_data_sharded_loader = True config.return_local_batch = False if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, tokenizer, text_processor, node_info): self.config = self.get_default_config(config) assert self.config.path != '' self._tokenizer = tokenizer self._text_processor = text_processor self._node_info = node_info self._index = self.config.example_index_at_start self._file_loc = self.config.start_seek_loc self._total_tokens = self.config.tokens_count_at_start def parse_json(self, line): if not line or line == '\n': return None try: data = json.loads(line) except json.decoder.JSONDecodeError: print(f'Error parsing json line:\n{line}') return None return data def json_iterator(self): index, file_loc = self._index, self._file_loc with open_file(self.config.path, 'r') as fin: fin.seek(file_loc) while True: line = fin.readline() file_loc = fin.tell() if not line: # Reached EOF index = 0 fin.seek(0) continue data = self.parse_json(line) if data is not None and (not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == 
self._node_info['dp_node_rank']): # JSON parsing succeeded yield data, file_loc, index index += 1 def batched(self, iterator, batch_size): batch = [] for example in iterator: batch.append(example) if len(batch) == batch_size: yield batch batch = [] if len(batch) > 0: yield batch def parallel_example_iterator(self): if self.config.tokenizer_processes == 1: for example, loc, index in self.json_iterator(): self._file_loc = loc self._index = index yield self.text_processor((example, loc, index), has_aux=True) else: process_pool = Pool(self.config.tokenizer_processes) batched_iterator = self.batched( self.json_iterator(), self.config.tokenizer_parallel_batch_size ) with process_pool as pool: map_fn = partial(self.text_processor, has_aux=True) next_batch = pool.map_async( map_fn, next(batched_iterator), chunksize=self.config.tokenizer_parallel_chunk_size ) while True: current_batch = next_batch next_batch = pool.map_async( map_fn, next(batched_iterator), chunksize=self.config.tokenizer_parallel_chunk_size ) for example in current_batch.get(): yield example def __iter__(self): global_chunk_size = self.config.batch_size * self.config.seq_length if self.config.use_data_sharded_loader: local_batch_size = self.config.batch_size // self._node_info['dp_node_size'] else: local_batch_size = self.config.batch_size chunk_size = local_batch_size * self.config.seq_length token_buffer = [] loss_mask_buffer = [] last_time = 0.0 step_times = [] start_time = time.time() start_tokens = self._total_tokens for tokens, loss_masks, loc, index in self.parallel_example_iterator(): self._file_loc = loc self._index = index if self.config.pad: tokens = tokens[:self.config.seq_length + 1] tokens.extend([self._tokenizer.bos_token_id] * (self.config.seq_length + 1 - len(tokens))) loss_masks = loss_masks[:self.config.seq_length + 1] loss_masks.extend([0.0] * (self.config.seq_length + 1 - len(loss_masks))) token_buffer.extend(tokens) loss_mask_buffer.extend(loss_masks) while len(token_buffer) > 
chunk_size + 1: self._total_tokens += global_chunk_size step_times.append(time.time() - last_time) last_time = time.time() if len(step_times) > self.config.throughput_average_window_size: step_times = step_times[-self.config.throughput_average_window_size:] average_throughput = global_chunk_size / np.mean(step_times) accumulated_throughput = ( (self._total_tokens - start_tokens) / (time.time() - start_time) ) metrics = { 'dataset_file_loc': loc, 'dataset_example_index': index, 'dataset_total_tokens': self._total_tokens, 'dataset_accumulated_tps': accumulated_throughput, 'dataset_average_tps': average_throughput, } batch = { 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape( local_batch_size, -1 ), 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape( local_batch_size, -1 ), 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape( local_batch_size, -1 ), } batch.update({ 'input_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool), 'target_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool), }) if self.config.always_start_with_bos: batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id if self.config.use_data_sharded_loader and not self.config.return_local_batch: mesh = self._node_info['mesh'] sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count()) sp_nodes_rank = jax.process_index() % sp_nodes_size assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size) seq_chunk_size = self.config.seq_length // sp_nodes_size batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()} batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp')) yield batch, metrics if self.config.pad: token_buffer, loss_mask_buffer = [], [] else: token_buffer = token_buffer[chunk_size:] loss_mask_buffer = loss_mask_buffer[chunk_size:] def _make_callback(self,
v): return lambda index: v[index] def get_state_dict(self): return dict( config=self.config, index=self._index, file_loc=self._file_loc, total_tokens=self._total_tokens, ) def load_state_dict(self, state_dict): if 'config' in state_dict: self.config.update(ConfigDict(state_dict['config'])) self._index = state_dict.get('index', self.config.example_index_at_start) self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc) self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start) @property def seq_length(self): return self.config.seq_length @property def tokenizer(self): return self._tokenizer @property def text_processor(self): return self._text_processor @property def vocab_size(self): return len(self.tokenizer) class JsonVisionDataset(object): @staticmethod def get_default_config(updates=None): config = ConfigDict() config.path = '' config.seq_length = 384 config.batch_size = 4 config.always_start_with_bos = False config.start_seek_loc = 0 config.example_index_at_start = 0 config.tokens_count_at_start = 0 config.tokenizer_processes = 1 config.tokenizer_parallel_chunk_size = 32 config.tokenizer_parallel_batch_size = 1024 config.throughput_average_window_size = 200 config.use_data_sharded_loader = True config.return_local_batch = False config.mode = 'pad' if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config def __init__(self, config, tokenizer, text_processor, node_info): self.config = self.get_default_config(config) assert self.config.path != '' self._node_info = node_info self._tokenizer = tokenizer self._text_processor = text_processor self._index = self.config.example_index_at_start self._file_loc = self.config.start_seek_loc self._total_tokens = 0 def parse_json(self, line): if not line or line == '\n': return None try: data = json.loads(line) except json.decoder.JSONDecodeError: print(f'Error parsing json line:\n{line}') return None return data def json_iterator(self): 
index, file_loc = self._index, self._file_loc with open_file(self.config.path, 'r', block_size=50 * 2 ** 20) as fin: fin.seek(file_loc) while True: line = fin.readline() file_loc = fin.tell() if not line: # Reached EOF index = 0 fin.seek(0) continue if not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']: data = self.parse_json(line) if data is not None: # JSON parsing succeeded yield data, file_loc, index index += 1 def batched(self, iterator, batch_size): batch = [] for example in iterator: batch.append(example) if len(batch) == batch_size: yield batch batch = [] if len(batch) > 0: yield batch def parallel_example_iterator(self): if self.config.tokenizer_processes == 1: for example, loc, index in self.json_iterator(): self._file_loc = loc self._index = index yield self.text_processor((example, loc, index), has_aux=True) else: process_pool = Pool(self.config.tokenizer_processes) batched_iterator = self.batched( self.json_iterator(), self.config.tokenizer_parallel_batch_size ) with process_pool as pool: map_fn = partial(self.text_processor, has_aux=True) next_batch = pool.map_async( map_fn, next(batched_iterator), chunksize=self.config.tokenizer_parallel_chunk_size ) while True: current_batch = next_batch next_batch = pool.map_async( map_fn, next(batched_iterator), chunksize=self.config.tokenizer_parallel_chunk_size ) for example in current_batch.get(): yield example def __iter__(self): if self.config.mode == 'pad': fn = self._iter_pad elif self.config.mode == 'no_pad': fn = self._iter_no_pad else: raise ValueError(f'Unknown mode: {self.config.mode}') return fn() def _iter_pad(self): chunk_size = self.config.batch_size * self.config.seq_length if self.config.use_data_sharded_loader: local_batch_size = self.config.batch_size // self._node_info['dp_node_size'] else: local_batch_size = self.config.batch_size last_time = 0.0 buffer = [] step_times = [] start_time = time.time() start_tokens = 
self._total_tokens for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator(): if not keep: continue self._file_loc = loc self._index = index buffer.append((tokens, loss_masks, vision_masks)) while len(buffer) >= local_batch_size: self._total_tokens += chunk_size step_times.append(time.time() - last_time) last_time = time.time() if len(step_times) > self.config.throughput_average_window_size: step_times = step_times[-self.config.throughput_average_window_size:] average_throughput = chunk_size / np.mean(step_times) accumulated_throughput = ( (self._total_tokens - start_tokens) / (time.time() - start_time) ) metrics = { 'dataset_file_loc': loc, 'dataset_example_index': index, 'dataset_total_tokens': self._total_tokens, 'dataset_accumulated_tps': accumulated_throughput, 'dataset_average_tps': average_throughput, } batch = { 'input_tokens': np.full( (local_batch_size, self.config.seq_length), self._tokenizer.bos_token_id, dtype=np.int32 ), 'target_tokens': np.full( (local_batch_size, self.config.seq_length), self._tokenizer.bos_token_id, dtype=np.int32 ), 'loss_masks': np.zeros( (local_batch_size, self.config.seq_length), dtype=np.float32 ), 'input_vision_masks': np.zeros( (local_batch_size, self.config.seq_length), dtype=bool ), 'target_vision_masks': np.zeros( (local_batch_size, self.config.seq_length), dtype=bool ) } for i in range(local_batch_size): tokens, loss_masks, vision_masks = buffer[i] if len(tokens) > self.config.seq_length: tokens = tokens[:self.config.seq_length + 1] loss_masks = loss_masks[1:self.config.seq_length + 1] vision_masks = vision_masks[:self.config.seq_length + 1] input_tokens, target_tokens = tokens[:-1], tokens[1:] input_vision_masks, target_vision_masks = vision_masks[:-1], vision_masks[1:] loss_masks = loss_masks[1:] batch['input_tokens'][i, :len(input_tokens)] = input_tokens batch['target_tokens'][i, :len(target_tokens)] = target_tokens batch['input_vision_masks'][i, :len(input_vision_masks)] = 
input_vision_masks batch['target_vision_masks'][i, :len(target_vision_masks)] = target_vision_masks batch['loss_masks'][i, :len(loss_masks)] = loss_masks if self.config.use_data_sharded_loader and not self.config.return_local_batch: mesh = self._node_info['mesh'] sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count()) sp_nodes_rank = jax.process_index() % sp_nodes_size assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size) seq_chunk_size = self.config.seq_length // sp_nodes_size batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()} batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp')) yield batch, metrics buffer = buffer[local_batch_size:] def _iter_no_pad(self): global_chunk_size = self.config.batch_size * self.config.seq_length if self.config.use_data_sharded_loader: local_batch_size = self.config.batch_size // self._node_info['dp_node_size'] else: local_batch_size = self.config.batch_size chunk_size = local_batch_size * self.config.seq_length token_buffer = [] loss_mask_buffer = [] vision_mask_buffer = [] last_time = 0.0 step_times = [] start_time = time.time() start_tokens = self._total_tokens for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator(): if not keep: continue self._file_loc = loc self._index = index token_buffer.extend(tokens) loss_mask_buffer.extend(loss_masks) vision_mask_buffer.extend(vision_masks) while len(token_buffer) > chunk_size + 1: self._total_tokens += global_chunk_size step_times.append(time.time() - last_time) last_time = time.time() if len(step_times) > self.config.throughput_average_window_size: step_times = step_times[-self.config.throughput_average_window_size:] average_throughput = global_chunk_size / np.mean(step_times) accumulated_throughput = ( (self._total_tokens - start_tokens) / (time.time() - start_time) ) metrics = { 'dataset_file_loc': loc,
'dataset_example_index': index, 'dataset_total_tokens': self._total_tokens, 'dataset_accumulated_tps': accumulated_throughput, 'dataset_average_tps': average_throughput, } batch = { 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape( local_batch_size, -1 ), 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape( local_batch_size, -1 ), 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape( local_batch_size, -1 ), 'input_vision_masks': np.array(vision_mask_buffer[:chunk_size], dtype=bool).reshape( local_batch_size, -1 ), 'target_vision_masks': np.array(vision_mask_buffer[1:chunk_size + 1], dtype=bool).reshape( local_batch_size, -1 ), } if self.config.use_data_sharded_loader and not self.config.return_local_batch: mesh = self._node_info['mesh'] sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count()) sp_nodes_rank = jax.process_index() % sp_nodes_size assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_length, sp_nodes_size) seq_chunk_size = self.config.seq_length // sp_nodes_size batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()} batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp')) yield batch, metrics token_buffer = token_buffer[chunk_size:] loss_mask_buffer = loss_mask_buffer[chunk_size:] vision_mask_buffer = vision_mask_buffer[chunk_size:] def _make_callback(self, v): return lambda index: v[index] def get_state_dict(self): return dict( config=self.config, index=self._index, file_loc=self._file_loc, total_tokens=self._total_tokens, ) def load_state_dict(self, state_dict): if 'config' in state_dict: self.config.update(ConfigDict(state_dict['config'])) self._index = state_dict.get('index', self.config.example_index_at_start) self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc) self._total_tokens = state_dict.get('total_tokens',
self.config.tokens_count_at_start) @property def seq_length(self): return self.config.seq_length @property def tokenizer(self): return self._tokenizer @property def text_processor(self): return self._text_processor @property def vocab_size(self): return len(self._tokenizer) ================================================ FILE: lwm/llama.py ================================================ import os from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union import json import tempfile from functools import partial import numpy as np import jax import jax.numpy as jnp from jax import lax from jax.sharding import PartitionSpec as PS from jax.experimental.shard_map import shard_map import flax.linen as nn from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen import partitioning as nn_partitioning import sentencepiece as spm from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging from transformers.tokenization_utils import PreTrainedTokenizer from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput from transformers.modeling_flax_utils import FlaxPreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from ml_collections import ConfigDict from tux import function_args_to_config, load_pickle, open_file, with_sharding_constraint, get_jax_mesh from ringattention import ringattention, blockwise_feedforward, ringattention_jax, ringattention_inference LLAMA_STANDARD_CONFIGS = { '200m': { 'vocab_size': 32000, 'hidden_size': 1024, 'intermediate_size': 2048, 'num_hidden_layers': 14, 'num_attention_heads': 8, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, '1b': { 'vocab_size': 32000, 'hidden_size': 
2048, 'intermediate_size': 5504, 'num_hidden_layers': 22, 'num_attention_heads': 16, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, '3b': { 'vocab_size': 32000, 'hidden_size': 3200, 'intermediate_size': 8640, 'num_hidden_layers': 26, 'num_attention_heads': 32, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, '7b': { 'vocab_size': 32000, 'hidden_size': 4096, 'intermediate_size': 11008, 'num_hidden_layers': 32, 'num_attention_heads': 32, 'max_sequence_length': 4096, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, '13b': { 'vocab_size': 32000, 'hidden_size': 5120, 'intermediate_size': 13824, 'num_hidden_layers': 40, 'num_attention_heads': 40, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, '30b': { 'vocab_size': 32000, 'hidden_size': 6656, 'intermediate_size': 17920, 'num_hidden_layers': 60, 'num_attention_heads': 52, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, '65b': { 'vocab_size': 32000, 'hidden_size': 8192, 'intermediate_size': 22016, 'num_hidden_layers': 80, 'num_attention_heads': 64, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-5, 'use_cache': True, 'tie_word_embeddings': False, }, 'debug': { # A small model for debugging 'vocab_size': 32000, 'hidden_size': 256, 'intermediate_size': 256, 'num_hidden_layers': 2, 'num_attention_heads': 2, 'max_sequence_length': 2048, 'initializer_range': 0.02, 'rms_norm_eps': 1e-6, 'use_cache': True, 'tie_word_embeddings': False, }, } class LLaMAConfig(PretrainedConfig): model_type = "llama" def __init__( self, vocab_size=32000, hidden_size=4096, intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, 
max_sequence_length=4096, rms_norm_eps=1e-6, initializer_range=0.02, use_cache=True, bos_token_id=0, eos_token_id=1, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, tie_word_embeddings=False, scan_attention=True, scan_mlp=True, scan_query_chunk_size=1024, scan_key_chunk_size=1024, scan_mlp_chunk_size=1024, scan_layers=True, param_scan_axis=0, mesh_dim=None, theta=10000, **kwargs, ): self.vocab_size = vocab_size self.hidden_size = hidden_size self.initializer_range = initializer_range self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.max_sequence_length = max_sequence_length self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.resid_pdrop = resid_pdrop self.embd_pdrop = embd_pdrop self.attn_pdrop = attn_pdrop self.scan_attention = scan_attention self.scan_mlp = scan_mlp self.scan_query_chunk_size = scan_query_chunk_size self.scan_key_chunk_size = scan_key_chunk_size self.scan_mlp_chunk_size = scan_mlp_chunk_size self.scan_layers = scan_layers self.param_scan_axis = param_scan_axis self.mesh_dim = mesh_dim self.theta = theta super().__init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) @classmethod def get_default_config(cls, updates=None): config = function_args_to_config(cls.__init__) if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) return config @staticmethod def get_jax_mesh(axis_dims): return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp')) @staticmethod def get_ranks_and_size(mesh): out = dict(mesh=mesh) mp_size = mesh.shape['tp'] * mesh.shape['sp'] mp_node_size = max(1, mp_size // jax.local_device_count()) dp_node_size = jax.process_count() // mp_node_size out.update(mp_node_size=mp_node_size, dp_node_size=dp_node_size) dp_node_rank = jax.process_index() // mp_node_size mp_node_rank = jax.process_index() % mp_node_size out.update(dp_node_rank=dp_node_rank, 
mp_node_rank=mp_node_rank) return out @staticmethod def get_partition_rules(scan_layers=False, scan_axis=0): """Partition rules are ordered, so that the beginning rules match first.""" if scan_layers: if scan_axis == 0: return ( # embeddings ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))), # attention ("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")), ("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))), # mlp ("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")), ("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))), ("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")), # layer norms ("attention_norm/kernel", PS(None, None)), ("ffn_norm/kernel", PS(None, None)), # output head ("transformer/ln_f/kernel", PS(None)), ("lm_head/kernel", PS(("fsdp", "sp"), "tp")), ('.*', PS(None)), ) elif scan_axis == 1: return ( # embeddings ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))), # attention ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")), ("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))), # mlp ("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")), ("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))), ("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")), # layer norms ("attention_norm/kernel", PS(None, None)), ("ffn_norm/kernel", PS(None, None)), # output head ("transformer/ln_f/kernel", PS(None)), ("lm_head/kernel", PS(("fsdp", "sp"), "tp")), ('.*', PS(None)), ) else: raise ValueError(f"Invalid scan_axis {scan_axis}") else: return ( # embeddings ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))), # attention ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")), ("attention/wo/kernel", PS("tp", ("fsdp", "sp"))), # mlp ("feed_forward/w1/kernel", PS(("fsdp", "sp"), "tp")), ("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))), ("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")), # layer norms ("attention_norm/kernel", PS(None)), ("ffn_norm/kernel", PS(None)), # output head
("transformer/ln_f/kernel", PS(None)), ("lm_head/kernel", PS(("fsdp", "sp"), "tp")), ('.*', PS(None)), ) @staticmethod def get_weight_decay_exclusions(): return tuple() @staticmethod def get_frozen_param_exclusions(freeze_base): if freeze_base: return ("vte", "vision_head") else: return tuple() @staticmethod def rng_keys(): return ('params', 'dropout') @classmethod def load_config(cls, path): if path in LLAMA_STANDARD_CONFIGS: return cls.from_dict(LLAMA_STANDARD_CONFIGS[path]) load_type, load_path = path.split('::', 1) if load_type == 'pickle': return cls.from_dict(load_pickle(load_path)['llama_config']) elif load_type == 'json': with open_file(load_path, 'r') as fin: raw_config = fin.read() return cls.from_dict(json.loads(raw_config)) else: raise ValueError(f'Unsupported load config type: {load_type}') remat = nn_partitioning.remat logger = logging.get_logger(__name__) class RMSNorm(nn.Module): dim: int eps: float=1e-6 dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 def setup(self) -> None: self.weight = self.param( 'kernel', nn.initializers.ones, (self.dim,), self.param_dtype, ) def _norm(self, x: jnp.ndarray) -> jnp.ndarray: return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = x.astype(jnp.promote_types(self.dtype, jnp.float32)) output = self._norm(x).astype(self.dtype) weight = jnp.asarray(self.weight, self.dtype) return output * weight def precompute_freqs_cis(dim: int, max_position_embedding: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray: freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) t = np.arange(max_position_embedding) # type: ignore freqs = np.outer(t, freqs).astype(dtype) # type: ignore sin, cos = np.sin(freqs), np.cos(freqs) freqs_cis = np.complex64(cos + 1j * sin) return jnp.asarray(freqs_cis) def apply_rotary_emb( xq: jnp.ndarray, xk: jnp.ndarray, freqs_cis: jnp.ndarray, dtype: jnp.dtype=jnp.float32, 
) -> Tuple[jnp.ndarray, jnp.ndarray]: reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1]) xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) # add head dim freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])) xq_out = xq_ * freqs_cis xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) xk_out = xk_ * freqs_cis xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) return xq_out.astype(dtype), xk_out.astype(dtype) class FlaxLLaMAAttention(nn.Module): config: LLaMAConfig dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self): config = self.config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads self.wq = nn.Dense( config.num_attention_heads*self.head_dim, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.wk = nn.Dense( config.num_attention_heads*self.head_dim, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.wv = nn.Dense( config.num_attention_heads*self.head_dim, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.wo = nn.Dense( config.hidden_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) self.causal_mask = 
make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool") self.freqs_cis = precompute_freqs_cis( self.head_dim, config.max_sequence_length, theta=config.theta, dtype=self.dtype, ) def _split_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) def _merge_heads(self, hidden_states): return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) @nn.compact def _concatenate_to_cache(self, key, value, query, attention_mask): # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape # update key, value caches with our new 1d spatial slices cur_index = cache_index.value if query.shape[1] == 1: mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim) def fn(cached_key, cached_value, key, value, cur_index): assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape) sp_size = max_length // mesh.shape['sp'] axis_index = jax.lax.axis_index('sp') cur_index = cur_index - axis_index * sp_size key, value = jax.lax.cond( jnp.logical_and(cur_index >= 0, cur_index < sp_size), lambda: ( cached_key.at[:, cur_index].set(key[:, -1]), cached_value.at[:, cur_index].set(value[:, -1]), ), lambda: (cached_key, cached_value), ) return key, value fn = shard_map( fn, mesh=mesh, in_specs=( PS(('dp', 'fsdp'), 'sp', 'tp', None), PS(('dp', 'fsdp'), 'sp', 'tp', None), PS(('dp', 'fsdp'), None, 'tp', None), PS(('dp', 'fsdp'), None, 'tp', None), PS() ), out_specs=( PS(('dp', 'fsdp'), 'sp', 'tp', None), PS(('dp', 'fsdp'), 'sp', 'tp', None) ), check_rep=False ) key, 
value = fn(cached_key.value, cached_value.value, key, value, cur_index) else: indices = (0,) * len(batch_dims) + (cur_index, 0, 0) key = lax.dynamic_update_slice(cached_key.value, key, indices) value = lax.dynamic_update_slice(cached_value.value, value, indices) cached_key.value = key cached_value.value = value num_updated_cache_vectors = query.shape[1] cache_index.value = cache_index.value + num_updated_cache_vectors return key, value, attention_mask def __call__( self, hidden_states, attention_mask, segment_ids, position_ids, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, ): xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) if xq.shape[1] == 1: xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "tp")) else: xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), "sp", "tp")) xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), "sp", "tp")) xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), "sp", "tp")) xq = self._split_heads(xq) xk = self._split_heads(xk) xv = self._split_heads(xv) freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype) dropout_rng = None if not deterministic and self.config.attn_pdrop > 0.0: dropout_rng = self.make_rng("dropout") if self.config.scan_attention and xq.shape[1] > max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size): # attention mask without nxn materialization, blockwise_attn will handle the rest attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) if self.has_variable("cache", "cached_key") or init_cache: xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask) # transform boolean mask into float mask attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), ) attn_weights = None ring_attention_sharded =
shard_map( partial( ringattention, axis_name="sp", float32_logits=True, cache_idx=None, blockwise_kwargs=dict( causal_block_size=1, deterministic=deterministic, dropout_rng=dropout_rng, attn_pdrop=self.config.attn_pdrop, query_chunk_size=self.config.scan_query_chunk_size, key_chunk_size=self.config.scan_key_chunk_size, dtype=self.dtype, policy=jax.checkpoint_policies.nothing_saveable, precision=self.precision, prevent_cse=not self.config.scan_layers, ) ), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim), in_specs=( PS(("dp", "fsdp"), "sp", "tp", None), PS(("dp", "fsdp"), "sp", "tp", None), PS(("dp", "fsdp"), "sp", "tp", None), PS(("dp", "fsdp"), None, None, None), PS(("dp", "fsdp"), None), ), out_specs=PS(("dp", "fsdp"), "sp", "tp", None), check_rep=False ) attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids) attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), "sp", "tp", None)) else: query_length, key_length = xq.shape[1], xk.shape[1] if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] causal_mask = jnp.arange(max_decoder_length)[None] <= (jnp.arange(query_length) + mask_shift)[:, None] causal_mask = causal_mask[None, None] segment_mask = None else: causal_mask = self.causal_mask[:, :, :query_length, :key_length] if segment_ids is not None: segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :] segment_mask = segment_mask[:, None] else: segment_mask = None batch_size = hidden_states.shape[0] causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask, segment_mask) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. 
if self.has_variable("cache", "cached_key") or init_cache: xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask) q_sp_dim = None if xq.shape[1] == 1 else 'sp' attn_weights = None ring_attention_sharded = shard_map( partial(ringattention_inference, axis_name="sp"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim), in_specs=( PS(("dp", "fsdp"), q_sp_dim, "tp", None), PS(("dp", "fsdp"), "sp", "tp", None), PS(("dp", "fsdp"), "sp", "tp", None), PS(("dp", "fsdp"), None, q_sp_dim, None) ), out_specs=PS(("dp", "fsdp"), q_sp_dim, "tp", None), check_rep=False ) attn_output = ring_attention_sharded( xq, xk, xv, attention_mask ) attn_output = self._merge_heads(attn_output) attn_output = self.wo(attn_output) attn_output = self.resid_dropout(attn_output, deterministic=deterministic) outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs class FlaxLLaMAMLP(nn.Module): config: LLaMAConfig dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self) -> None: config = self.config self.w1 = nn.Dense( config.intermediate_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.w2 = nn.Dense( config.hidden_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.w3 = nn.Dense( config.intermediate_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), precision=self.precision, ) self.dropout = nn.Dropout(rate=self.config.resid_pdrop) def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: x = self.w2(nn.silu(self.w1(x)) * self.w3(x)) x = self.dropout(x, deterministic=deterministic) return x class 
FlaxLLaMABlock(nn.Module): config: LLaMAConfig dtype: jnp.dtype=jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self) -> None: attention_module = FlaxLLaMAAttention mlp_module = FlaxLLaMAMLP if self.config.scan_mlp: mlp_module = remat( mlp_module, static_argnums=(1,), policy=jax.checkpoint_policies.nothing_saveable, prevent_cse=not self.config.scan_layers, ) self.attention = attention_module( self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, ) self.feed_forward = mlp_module( self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, ) self.attention_norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype, ) self.ffn_norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype, ) def __call__( self, hidden_states, attention_mask=None, segment_ids=None, position_ids=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, ): attn_outputs = self.attention( self.attention_norm(hidden_states), attention_mask, segment_ids, position_ids, deterministic, init_cache, output_attentions, ) attn_output = attn_outputs[0] hidden_states = hidden_states + attn_output feed_forward_input = self.ffn_norm(hidden_states) if self.config.scan_mlp and hidden_states.shape[1] >= self.config.scan_mlp_chunk_size: feed_forward_hidden_states = blockwise_feedforward( self.feed_forward, feed_forward_input, self.config.scan_mlp_chunk_size, pre_remat=True, ) else: feed_forward_hidden_states = self.feed_forward(feed_forward_input, deterministic) feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS(("dp", "fsdp"), None, "tp")) hidden_states = hidden_states + feed_forward_hidden_states outputs = hidden_states if self.config.scan_layers: outputs = (outputs, None) return outputs class 
FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = LLaMAConfig base_model_prefix = "transformer" module_class: nn.Module = None def __init__( self, config: LLaMAConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) segment_ids = None position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} if self.config.add_cross_attention: encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) encoder_attention_mask = attention_mask module_init_outputs = self.module.init( rngs, input_ids, attention_mask, segment_ids, position_ids, encoder_hidden_states, encoder_attention_mask, return_dict=False, ) else: module_init_outputs = self.module.init(rngs, input_ids, attention_mask, segment_ids, position_ids, return_dict=False) random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params def init_cache(self, batch_size, max_length): r""" Args: batch_size (`int`): batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. 
max_length (`int`): maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized cache. """ # init input variables to retrieve cache input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) segment_ids = None position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) init_variables = self.module.init( jax.random.PRNGKey(0), input_ids, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True ) return init_variables["cache"] @add_start_docstrings_to_model_forward("") def __call__( self, input_ids, attention_mask=None, segment_ids=None, position_ids=None, params: dict = None, past_key_values: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict batch_size, sequence_length = input_ids.shape if position_ids is None: if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), segment_ids, jnp.array(position_ids, 
dtype="i4"), not train, False, output_attentions, output_hidden_states, return_dict, rngs=rngs, mutable=mutable, ) # add updated cache to model output if past_key_values is not None and return_dict: outputs, past_key_values = outputs outputs["past_key_values"] = unfreeze(past_key_values["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past_key_values = outputs outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] return outputs class FlaxLLaMABlockCollection(nn.Module): config: LLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None @nn.compact def __call__( self, hidden_states, attention_mask=None, segment_ids=None, position_ids=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None block = FlaxLLaMABlock if self.config.scan_layers: initializing = self.is_mutable_collection('params') params_spec = ( self.config.param_scan_axis if initializing else nn_partitioning.ScanIn(self.config.param_scan_axis)) cache_spec = 0 hidden_states, _ = nn.scan( block, variable_axes={ 'params': params_spec, 'cache': cache_spec, 'intermediates': 0 }, split_rngs={ 'params': True, 'dropout': True }, in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), length=self.config.num_hidden_layers, metadata_params={nn.PARTITION_NAME: 'scan_decoder_layer'}, )(self.config, name='scan_decoder', dtype=self.dtype, param_dtype=self.param_dtype,)( hidden_states, attention_mask, segment_ids, position_ids, deterministic, init_cache, output_attentions, ) else: blocks = [ block( self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, ) for i in range(self.config.num_hidden_layers) ] for block in blocks: if 
output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = block( hidden_states, attention_mask, segment_ids, position_ids, deterministic, init_cache, output_attentions, ) hidden_states = layer_outputs if output_attentions: all_attentions += (layer_outputs[1],) outputs = (hidden_states, all_hidden_states, all_attentions) return outputs class FlaxLLaMAModule(nn.Module): config: LLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self): self.embed_dim = self.config.hidden_size self.wte = nn.Embed( self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, param_dtype=self.param_dtype, ) self.dropout = nn.Dropout(rate=self.config.embd_pdrop) self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype) def __call__( self, input_ids, attention_mask, segment_ids, position_ids, deterministic=True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): input_embeds = self.wte(input_ids.astype("i4")) assert input_embeds.shape[1] <= self.config.max_sequence_length, f"Input sequence length {input_embeds.shape[1]} larger than max supported sequence length {self.config.max_sequence_length}" hidden_states = self.dropout(input_embeds, deterministic=deterministic) outputs = self.h( hidden_states, attention_mask, segment_ids=segment_ids, position_ids=position_ids, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = 
outputs[1] + (hidden_states,) outputs = (hidden_states, all_hidden_states) + outputs[2:] else: outputs = (hidden_states,) + outputs[1:] if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput( last_hidden_state=hidden_states, hidden_states=outputs[1], attentions=outputs[-1], ) class FlaxLLaMAForCausalLMModule(nn.Module): config: LLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self): self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.vocab_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), precision=self.precision, ) def __call__( self, input_ids, attention_mask=None, segment_ids=None, position_ids=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): batch_size, seq_length = input_ids.shape if attention_mask is None: attention_mask = jnp.ones_like(input_ids) if position_ids is None: position_ids = jnp.arange(seq_length, dtype=jnp.int32)[None].repeat(batch_size, axis=0) outputs = self.transformer( input_ids, attention_mask, segment_ids, position_ids, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) else: lm_logits = self.lm_head(hidden_states) if not return_dict: return (lm_logits,) + outputs[1:] return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) @add_start_docstrings("", "") class 
FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): module_class = FlaxLLaMAForCausalLMModule def prepare_inputs_for_generation( self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, ): # initializing the cache batch_size, seq_length = input_ids.shape past_key_values = self.init_cache(batch_size, max_length) extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) return { "past_key_values": past_key_values, "attention_mask": extended_attention_mask, "position_ids": position_ids, } def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 return model_kwargs ================================================ FILE: lwm/train.py ================================================ import pprint import os from functools import partial from tqdm import tqdm, trange import numpy as np from absl.app import run import absl.logging as logging import tux import jax import flax import jax.numpy as jnp from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as PS from flax.training.train_state import TrainState from transformers import AutoTokenizer from lwm.data import DatasetFactory from tux import ( JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name, set_random_seed, average_metrics, get_mask, make_shard_and_gather_fns, with_sharding_constraint, define_flags_with_default, OptimizerFactory, StreamingCheckpointer ) from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLMModule from lwm.vision_llama import 
VideoLLaMAConfig, FlaxVideoLLaMAForCausalLMModule FLAGS, FLAGS_DEF = define_flags_with_default( modality='text', use_data_sharded_loader=True, seed=42, mesh_dim='1,-1,1,1', dtype='fp32', total_steps=10000, load_llama_config='', update_llama_config='', load_checkpoint='', load_dataset_state='', log_freq=50, save_model_freq=0, save_milestone_freq=0, eval_steps=0, tokenizer='LargeWorldModel/LWM-Text-1M', train_dataset=DatasetFactory.get_default_config(), eval_dataset=DatasetFactory.get_default_config(), optimizer=OptimizerFactory.get_default_config(), checkpointer=StreamingCheckpointer.get_default_config(), llama=VideoLLaMAConfig.get_default_config(), logger=tux.WandBLogger.get_default_config(), log_all_worker=False, jax_distributed=JaxDistributedConfig.get_default_config(), autoresume=False, ) def main(argv): JaxDistributedConfig.initialize(FLAGS.jax_distributed) variant = tux.get_user_flags(FLAGS, FLAGS_DEF) flags_config_dict = tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF) logger = tux.WandBLogger( config=FLAGS.logger, variant=variant, enable=FLAGS.log_all_worker or (jax.process_index() == 0), ) set_random_seed(FLAGS.seed) if jax.process_index() == 0: output_dir = logger.output_dir else: output_dir = os.path.join(logger.output_dir, logger.experiment_id) if FLAGS.modality == 'text': config_cls = LLaMAConfig llama_cls = FlaxLLaMAForCausalLMModule elif FLAGS.modality == 'vision,text': config_cls = VideoLLaMAConfig llama_cls = FlaxVideoLLaMAForCausalLMModule else: raise ValueError(f"Unsupported modality: {FLAGS.modality}") mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim) node_info = config_cls.get_ranks_and_size(mesh) tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info) if FLAGS.autoresume and tux.check_exists(output_dir): logging.info('Found existing output. 
Resuming dataset from latest checkpoint...') resume_path = f"{output_dir}/dataset.pkl" dataset.load_state_dict(tux.load_pickle(resume_path)) elif FLAGS.load_dataset_state != '': dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state)) if FLAGS.eval_steps > 0: eval_dataset = DatasetFactory.load_dataset( FLAGS.eval_dataset, dataset.tokenizer ) eval_iterator = iter(eval_dataset) seq_length = dataset.seq_length if FLAGS.load_llama_config != '': llama_config = config_cls.load_config(FLAGS.load_llama_config) updates = config_cls(**FLAGS.llama) llama_config.update(dict( scan_attention=updates.scan_attention, scan_mlp=updates.scan_mlp, scan_query_chunk_size=updates.scan_query_chunk_size, scan_key_chunk_size=updates.scan_key_chunk_size, scan_mlp_chunk_size=updates.scan_mlp_chunk_size, scan_layers=updates.scan_layers, param_scan_axis=updates.param_scan_axis, )) else: llama_config = config_cls(**FLAGS.llama) if FLAGS.update_llama_config != '': llama_config.update(dict(eval(FLAGS.update_llama_config))) llama_config.update(dict( bos_token_id=dataset.tokenizer.bos_token_id, eos_token_id=dataset.tokenizer.eos_token_id, )) if llama_config.vocab_size < dataset.vocab_size: llama_config.update(dict(vocab_size=dataset.vocab_size)) llama_config.update(dict(mesh_dim=FLAGS.mesh_dim)) model = llama_cls( llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype) ) optimizer, optimizer_info = OptimizerFactory.get_optimizer( FLAGS.optimizer, get_mask(config_cls.get_weight_decay_exclusions()), None, ) def create_trainstate_from_params(params): return TrainState.create(params=params, tx=optimizer, apply_fn=None) def init_fn(rng): rng_generator = JaxRNG(rng) batch = 512 if FLAGS.modality == 'text': params = model.init( input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32), position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32), attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32), rngs=rng_generator(llama_config.rng_keys()), ) elif FLAGS.modality == 'vision,text': 
params = model.init( input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32), vision_masks=jnp.zeros((batch, seq_length), dtype=bool), position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32), attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32), rngs=rng_generator(llama_config.rng_keys()), ) else: raise ValueError(f"Unsupported modality: {FLAGS.modality}") return TrainState.create(params=params, tx=optimizer, apply_fn=None) def train_step(train_state, rng, batch): rng_generator = JaxRNG(rng) batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp')) def loss_and_accuracy(params): if FLAGS.modality == 'text': logits = model.apply( params, batch['input_tokens'], deterministic=False, rngs=rng_generator(llama_config.rng_keys()), ).logits loss, acc = cross_entropy_loss_and_accuracy( logits, batch['target_tokens'], batch['loss_masks'] ) metrics = dict(acc=acc) return loss, metrics elif FLAGS.modality == 'vision,text': vision_logits, text_logits = model.apply( params, batch['input_tokens'], batch['input_vision_masks'], deterministic=False, rngs=rng_generator(llama_config.rng_keys()), ).logits vision_loss, vision_acc = cross_entropy_loss_and_accuracy( vision_logits, jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0), batch['loss_masks'] * batch['target_vision_masks'] ) text_loss, text_acc = cross_entropy_loss_and_accuracy( text_logits, jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']), batch['loss_masks'] * (1.0 - batch['target_vision_masks']) ) loss = 0.5 * (vision_loss + text_loss) metrics = dict( vision_loss=vision_loss, vision_acc=vision_acc, text_loss=text_loss, text_acc=text_acc, ) else: raise ValueError(f"Unsupported modality: {FLAGS.modality}") return loss, metrics grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) (loss, loss_metrics), grads = grad_fn(train_state.params) train_state = train_state.apply_gradients(grads=grads) metrics = dict( loss=loss, 
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step), param_norm=global_norm(train_state.params), gradient_norm=global_norm(grads), **loss_metrics ) return train_state, rng_generator(), metrics def eval_step(train_state, rng, batch): rng_generator = JaxRNG(rng) batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp')) if FLAGS.modality == 'text': logits = model.apply( train_state.params, batch['input_tokens'], deterministic=True, rngs=rng_generator(llama_config.rng_keys()), ).logits loss, acc = cross_entropy_loss_and_accuracy( logits, batch['target_tokens'], batch['loss_masks'] ) metrics = dict( eval_loss=loss, eval_acc=acc, ) elif FLAGS.modality == 'vision,text': vision_logits, text_logits = model.apply( train_state.params, batch['input_tokens'], batch['input_vision_masks'], deterministic=True, rngs=rng_generator(llama_config.rng_keys()), ).logits vision_loss, vision_acc = cross_entropy_loss_and_accuracy( vision_logits, jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0), batch['loss_masks'] * batch['target_vision_masks'] ) text_loss, text_acc = cross_entropy_loss_and_accuracy( text_logits, jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']), batch['loss_masks'] * (1.0 - batch['target_vision_masks']) ) loss = 0.5 * (vision_loss + text_loss) metrics = dict( eval_loss=loss, eval_vision_accuracy=vision_acc, eval_vision_loss=vision_loss, eval_text_accuracy=text_acc, eval_text_loss=text_loss, ) return rng_generator(), metrics train_state_shapes = jax.eval_shape(init_fn, next_rng()) train_state_partition = match_partition_rules( config_cls.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), train_state_shapes ) shard_fns, gather_fns = make_shard_and_gather_fns( train_state_partition, train_state_shapes ) checkpointer = StreamingCheckpointer( FLAGS.checkpointer, logger.output_dir, enable=jax.process_index() == 0, ) sharded_init_fn = pjit( init_fn, in_shardings=PS(), 
out_shardings=train_state_partition ) sharded_create_trainstate_from_params = pjit( create_trainstate_from_params, in_shardings=(train_state_partition.params, ), out_shardings=train_state_partition, donate_argnums=(0, ), ) if FLAGS.use_data_sharded_loader: batch_spec = PS(('dp', 'fsdp'), 'sp') else: batch_spec = PS() sharded_train_step = pjit( train_step, in_shardings=(train_state_partition, PS(), batch_spec), out_shardings=(train_state_partition, PS(), PS()), donate_argnums=(0, 1), ) sharded_eval_step = pjit( eval_step, in_shardings=(train_state_partition, PS(), PS()), out_shardings=(PS(), PS()), donate_argnums=(1,), ) def save_checkpoint(train_state, milestone=False): step = int(jax.device_get(train_state.step)) metadata = dict( step=step, variant=variant, flags=flags_config_dict, llama_config=llama_config.to_dict(), ) checkpointer.save_all( train_state=train_state, gather_fns=gather_fns, metadata=metadata, dataset=dataset.get_state_dict(), milestone=milestone, ) with mesh: train_state, restored_params = None, None if FLAGS.autoresume and tux.check_exists(output_dir): logging.info('Found existing output. 
Resuming model from latest checkpoint...') resume_path = f"trainstate::{output_dir}/streaming_train_state" train_state, restored_params = checkpointer.load_trainstate_checkpoint( resume_path, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30 ) elif FLAGS.load_checkpoint != '': train_state, restored_params = checkpointer.load_trainstate_checkpoint( FLAGS.load_checkpoint, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30 ) if train_state is None and restored_params is None: # Initialize from scratch train_state = sharded_init_fn(next_rng()) elif train_state is None and restored_params is not None: # Restore from params but initialize train_state train_state = sharded_create_trainstate_from_params(flax.core.unfreeze(restored_params)) del restored_params start_step = int(jax.device_get(train_state.step)) if FLAGS.save_model_freq > 0: save_checkpoint(train_state) sharded_rng = next_rng() step_counter = trange(start_step, FLAGS.total_steps, ncols=0) for step, (batch, dataset_metrics) in zip(step_counter, dataset): train_state, sharded_rng, metrics = sharded_train_step( train_state, sharded_rng, batch ) if step % FLAGS.log_freq == 0: if FLAGS.eval_steps > 0: eval_metric_list = [] for _ in range(FLAGS.eval_steps): eval_batch, _ = next(eval_iterator) sharded_rng, eval_metrics = sharded_eval_step( train_state, sharded_rng, eval_batch ) eval_metrics = jax.device_get(eval_metrics) eval_metric_list.append(eval_metrics) metrics.update(average_metrics(eval_metric_list)) log_metrics = {"step": step} log_metrics.update(metrics) log_metrics.update(dataset_metrics) log_metrics = jax.device_get(log_metrics) logger.log(log_metrics) tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0: save_checkpoint(train_state, milestone=True) elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0: save_checkpoint(train_state) if FLAGS.save_model_freq > 0: 
save_checkpoint(train_state) if __name__ == "__main__": run(main) ================================================ FILE: lwm/vision_chat.py ================================================ from absl.app import run import math from tqdm import tqdm from PIL import Image import decord from functools import cached_property import numpy as np import jax from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as PS from transformers import GenerationConfig, AutoTokenizer from tux import ( define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig, set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng, match_partition_rules, make_shard_and_gather_fns, with_sharding_constraint, tree_apply, open_file ) from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM from lwm.vqgan import VQGAN FLAGS, FLAGS_DEF = define_flags_with_default( prompt="", input_file="", vqgan_checkpoint="", temperature=0.2, max_n_frames=8, seed=1234, mesh_dim='1,-1,1,1', dtype='fp32', load_llama_config='', update_llama_config='', load_checkpoint='', tokenizer='LargeWorldModel/LWM-Text-1M', llama=VideoLLaMAConfig.get_default_config(), jax_distributed=JaxDistributedConfig.get_default_config(), ) class Sampler: def __init__(self): self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False) self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left') self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) self.n_tokens_per_frame = 257 self.min_buffer_size = 256 self.sharded_rng = next_rng() self._load_model() @property def block_size(self): return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp'] @property def data_dim(self): return self.mesh.shape['dp'] * self.mesh.shape['fsdp'] def _process_frame(self, image, size): width, height = image.size if width < height: new_width = size 
new_height = int(size * height / width) else: new_height = size new_width = int(size * width / height) image = image.resize((new_width, new_height)) left = (new_width - size) / 2 top = (new_height - size) / 2 right = (new_width + size) / 2 bottom = (new_height + size) / 2 image = image.crop((left, top, right, bottom)) return np.array(image, dtype=np.float32) / 127.5 - 1 def _read_process_vision(self, path, max_n_frames): f = open_file(path, 'rb') if path.endswith('.png') or path.endswith('.jpg'): image = Image.open(f).convert('RGB') vision = self._process_frame(image, 256)[None] else: vr = decord.VideoReader(f, ctx=decord.cpu(0)) duration = len(vr) if duration <= max_n_frames: frame_id_list = list(range(duration)) else: frame_id_list = np.linspace(0, duration - 1, max_n_frames, dtype=int).tolist() video = vr.get_batch(frame_id_list).asnumpy() vision = np.stack([self._process_frame(Image.fromarray(frame), 256) for frame in video]) B = 1 encodings = [] for i in range(0, len(vision), 1): v = vision[i:i + B] if len(v) % B == 0: n_pad = 0 else: n_pad = B - len(v) % B v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0))) enc = jax.device_get(self.vqgan.encode(v))[1].astype(int) enc = enc[n_pad:] for t in range(len(enc)): encodings.extend(enc[t].reshape(-1).tolist()) if t == len(enc) - 1: encodings.append(8193) else: encodings.append(8192) return encodings def construct_input(self, prompts, max_n_frames): max_input_length = max_n_frames * self.n_tokens_per_frame + self.min_buffer_size max_input_length = int(math.ceil(max_input_length / self.block_size) * self.block_size) vision_start = self.tokenizer.encode('<vision>') vision_end = self.tokenizer.encode('</vision>') input_ids = np.zeros((len(prompts), max_input_length), dtype=int) vision_masks = np.zeros((len(prompts), max_input_length), dtype=bool) attention_mask = np.zeros((len(prompts), max_input_length), dtype=int) for i, prompt in enumerate(tqdm(prompts)): vision = self._read_process_vision(prompt['input_path'], max_n_frames) text_1 = 
self.tokenizer.encode(f"You are a helpful assistant. USER: {prompt['question']}\n") tail = self.tokenizer.encode(" ASSISTANT:") tokens, vm = [], [] tokens.extend(text_1) vm.extend([False] * len(text_1)) tokens.extend(vision_start) vm.extend([False] * len(vision_start)) tokens.extend(vision) vm.extend([True] * len(vision)) tokens.extend(vision_end) vm.extend([False] * len(vision_end)) tokens.extend(tail) vm.extend([False] * len(tail)) assert len(tokens) < max_input_length, (len(tokens), max_input_length) assert len(tokens) == len(vm) input_ids[i, -len(tokens):] = tokens vision_masks[i, -len(tokens):] = vm attention_mask[i, -len(tokens):] = 1 return { 'input_ids': input_ids, 'vision_masks': vision_masks, 'attention_mask': attention_mask } def _load_model(self): if FLAGS.load_llama_config != '': llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config) updates = VideoLLaMAConfig(**FLAGS.llama) llama_config.update(dict( scan_attention=updates.scan_attention, scan_mlp=updates.scan_mlp, scan_query_chunk_size=updates.scan_query_chunk_size, scan_key_chunk_size=updates.scan_key_chunk_size, scan_mlp_chunk_size=updates.scan_mlp_chunk_size, scan_layers=updates.scan_layers, param_scan_axis=updates.param_scan_axis, )) else: llama_config = VideoLLaMAConfig(**FLAGS.llama) if FLAGS.update_llama_config != '': llama_config.update(dict(eval(FLAGS.update_llama_config))) llama_config.update(dict( bos_token_id=self.tokenizer.bos_token_id, eos_token_id=self.tokenizer.eos_token_id, )) llama_config.update(dict(mesh_dim=FLAGS.mesh_dim)) self.config = llama_config self.model = FlaxVideoLLaMAForCausalLM( llama_config, input_shape=(512, self.block_size), seed=FLAGS.seed, _do_init=False, dtype=get_float_dtype_by_name(FLAGS.dtype), ) with jax.default_device(jax.devices("cpu")[0]): _, self.params = StreamingCheckpointer.load_trainstate_checkpoint( FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30 ) self.model_ps = match_partition_rules( 
VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params ) shard_fns, _ = make_shard_and_gather_fns( self.model_ps, get_float_dtype_by_name(FLAGS.dtype) ) with self.mesh: self.params = tree_apply(shard_fns, self.params) @cached_property def _forward_generate(self): def fn(params, rng, batch): batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp')) rng_generator = JaxRNG(rng) output = self.model.generate( batch['input_ids'], vision_masks=batch['vision_masks'], attention_mask=batch['attention_mask'], params=params['params'], prng_key=rng_generator(), generation_config=GenerationConfig( max_new_tokens=self.block_size, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, temperature=FLAGS.temperature, do_sample=True, ) ).sequences[:, batch['input_ids'].shape[1]:] return output, rng_generator() return pjit( fn, in_shardings=(self.model_ps, PS(), PS()), out_shardings=(PS(), PS()) ) def __call__(self, prompts, max_n_frames): batch = self.construct_input(prompts, max_n_frames) with self.mesh: output, self.sharded_rng = self._forward_generate( self.params, self.sharded_rng, batch ) output = jax.device_get(output) output_text = [] for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)): if self.tokenizer.eos_token in text: text = text.split(self.tokenizer.eos_token, maxsplit=1)[0] output_text.append(text) return output_text def main(argv): assert FLAGS.prompt != '' assert FLAGS.input_file != '' JaxDistributedConfig.initialize(FLAGS.jax_distributed) set_random_seed(FLAGS.seed) prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}] sampler = Sampler() output = sampler(prompts, FLAGS.max_n_frames)[0] print(f"Question: {FLAGS.prompt}\nAnswer: {output}") if __name__ == "__main__": run(main) ================================================ FILE: lwm/vision_generation.py ================================================ from absl.app import run from tqdm 
import tqdm import imageio import numpy as np from PIL import Image from transformers import GenerationConfig, AutoTokenizer import jax import jax.numpy as jnp from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as PS from tux import ( define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig, set_random_seed, get_float_dtype_by_name, JaxRNG, match_partition_rules, make_shard_and_gather_fns, with_sharding_constraint, tree_apply, next_rng ) from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM from lwm.vqgan import VQGAN FLAGS, FLAGS_DEF = define_flags_with_default( prompt='Fireworks over the city', output_file='', temperature_image=1.0, temperature_video=1.0, top_k_image=8192, top_k_video=100, cfg_scale_image=1.0, cfg_scale_video=1.0, vqgan_checkpoint='', n_frames=1, seed=1234, mesh_dim='1,-1,1,1', dtype='fp32', load_llama_config='', update_llama_config='', load_checkpoint='', tokenizer='LargeWorldModel/LWM-Text-1M', llama=VideoLLaMAConfig.get_default_config(), jax_distributed=JaxDistributedConfig.get_default_config(), ) def main(argv): assert FLAGS.output_file != '' if FLAGS.output_file.endswith('mp4'): assert FLAGS.n_frames > 1 elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'): assert FLAGS.n_frames == 1 else: raise ValueError(f"Unsupported output file extension: {FLAGS.output_file}") JaxDistributedConfig.initialize(FLAGS.jax_distributed) set_random_seed(FLAGS.seed) tokens_per_frame = 257 vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False) mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left') if FLAGS.load_llama_config != '': llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config) updates = VideoLLaMAConfig(**FLAGS.llama) llama_config.update(dict( scan_attention=updates.scan_attention, 
scan_mlp=updates.scan_mlp, scan_query_chunk_size=updates.scan_query_chunk_size, scan_key_chunk_size=updates.scan_key_chunk_size, scan_mlp_chunk_size=updates.scan_mlp_chunk_size, scan_layers=updates.scan_layers, param_scan_axis=updates.param_scan_axis, )) else: llama_config = VideoLLaMAConfig(**FLAGS.llama) if FLAGS.update_llama_config != '': llama_config.update(dict(eval(FLAGS.update_llama_config))) llama_config.update(dict( bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, )) llama_config.update(dict(mesh_dim=FLAGS.mesh_dim)) with jax.default_device(jax.devices("cpu")[0]): _, params = StreamingCheckpointer.load_trainstate_checkpoint( FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30 ) model = FlaxVideoLLaMAForCausalLM( llama_config, input_shape=(512, 8192), seed=FLAGS.seed, _do_init=False, dtype=get_float_dtype_by_name(FLAGS.dtype), ) model_ps = match_partition_rules( VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params ) shard_fns, _ = make_shard_and_gather_fns( model_ps, get_float_dtype_by_name(FLAGS.dtype) ) with mesh: params = tree_apply(shard_fns, params) def _forward_generate(params, rng, batch, n_tokens, cfg_scale, top_k, temperature): batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp')) cfg_scales = jnp.ones((batch['input_ids'].shape[0] // 2,), dtype=jnp.float32) * cfg_scale cfg_scales = with_sharding_constraint(cfg_scales, PS(('dp', 'fsdp'))) rng_generator = JaxRNG(rng) output = model.generate_vision( batch['input_ids'], cfg_scales, attention_mask=batch['attention_mask'], vision_masks=batch['vision_masks'], params=params['params'], prng_key=rng_generator(), generation_config=GenerationConfig( max_new_tokens=n_tokens, min_new_tokens=n_tokens, pad_token_id=tokenizer.pad_token_id, temperature=temperature, do_sample=True, top_k=top_k, ) ).sequences[:, batch['input_ids'].shape[1]:] return output, rng_generator() _sharded_forward_generate = pjit( 
_forward_generate, in_shardings=(model_ps, PS(), PS()), out_shardings=(PS(), PS()), static_argnums=(3, 4, 5, 6) ) # Generate an image or first frame (for video) def generate_first_frame(prompts, max_input_length): nonlocal sharded_rng uncond_prompts = [""] * len(prompts) prompts = prompts + uncond_prompts inputs = prefix_tokenizer( prompts, padding='max_length', truncation=True, max_length=max_input_length, return_tensors='np' ) batch = dict( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, vision_masks=np.zeros(inputs.input_ids.shape, dtype=bool), ) with mesh: output, sharded_rng = _sharded_forward_generate( params, sharded_rng, batch, tokens_per_frame, FLAGS.cfg_scale_image, FLAGS.top_k_image, FLAGS.temperature_image ) output = jax.device_get(output) output = np.split(output, 2, axis=0)[0] output = output.reshape(len(prompts) // 2, tokens_per_frame) image = vqgan.decode(output[:, :-1].reshape(-1, 16, 16)) image = ((jax.device_get(image) + 1) * 127.5).astype(np.uint8) return output, image sharded_rng = next_rng() prompts = [FLAGS.prompt] entries = [] for prompt in prompts: entries.append({ 'caption': prompt, 'prompt': f"You are a helpful assistant. 
USER: Generate an image of {prompt} ASSISTANT: ", }) B = 1 images, image_encodings = [], [] for i in tqdm(list(range(0, len(entries), B))): entries_i = entries[i:i + B] prompts = [entry['prompt'] for entry in entries_i] img_enc, img = generate_first_frame(prompts, max_input_length=128) image_encodings.extend(img_enc) images.extend(img) if FLAGS.n_frames == 1: image = images[0] Image.fromarray(image).save(FLAGS.output_file) return # Generate the rest of the video def generate_video_pred(prompts, images, max_input_length): nonlocal sharded_rng images = np.concatenate([images, images], axis=0) uncond_prompts = [""] * len(prompts) prompts = prompts + uncond_prompts inputs = prefix_tokenizer( prompts, padding='max_length', truncation=True, max_length=max_input_length, return_tensors='np' ) batch = dict( input_ids=np.concatenate([inputs.input_ids, images], axis=1), attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype)], axis=1), vision_masks=np.concatenate([ np.zeros(inputs.input_ids.shape, dtype=bool), np.ones(images.shape, dtype=bool) ], axis=1), ) with mesh: output, sharded_rng = _sharded_forward_generate( params, sharded_rng, batch, (FLAGS.n_frames - 1) * tokens_per_frame, FLAGS.cfg_scale_video, FLAGS.top_k_video, FLAGS.temperature_video ) output = jax.device_get(output) output = np.split(output, 2, axis=0)[0] output = output.reshape(len(prompts) // 2, FLAGS.n_frames - 1, tokens_per_frame) output = np.concatenate([images[:len(prompts) // 2, None], output], axis=1) output = output[:, :, :-1].reshape(-1, FLAGS.n_frames, 16, 16) vision = [] for v in output: v = vqgan.decode(v) v = ((jax.device_get(v) + 1) * 127.5).astype(np.uint8) vision.append(v) return vision new_entries = [] for img_enc, entry in zip(image_encodings, entries): new_entries.append({ 'caption': entry['caption'], 'prompt': f"You are a helpful assistant. 
USER: Generate a video of {entry['caption']} ASSISTANT: ", 'image': np.array(img_enc, dtype=np.int32), }) entries = new_entries B = 1 videos = [] for i in tqdm(list(range(0, len(entries), B))): entries_i = entries[i:i + B] prompts = [entry['prompt'] for entry in entries_i] images = np.array([entry['image'] for entry in entries_i], dtype=np.int32) videos.extend(generate_video_pred(prompts, images, max_input_length=128)) video = videos[0] writer = imageio.get_writer(FLAGS.output_file, fps=4) for frame in video: writer.append_data(frame) writer.close() print('done') if __name__ == "__main__": run(main) ================================================ FILE: lwm/vision_llama.py ================================================ from typing import Any, Dict, List, Optional, Tuple, Union import json import warnings import copy import jax import jax.numpy as jnp from jax import lax from jax.sharding import PartitionSpec as PS import flax.linen as nn from flax.core.frozen_dict import unfreeze, freeze from flax.traverse_util import flatten_dict, unflatten_dict from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel from transformers.generation.flax_utils import SampleState, FlaxLogitsProcessorList, FlaxSampleOutput, logger from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers import GenerationConfig from tux import load_pickle, open_file from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm VIDEO_LLAMA_STANDARD_CONFIGS = LLAMA_STANDARD_CONFIGS class VideoLLaMAConfig(LLaMAConfig): model_type = "video_llama" def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_mode='all', **kwargs): super().__init__(**kwargs) self.vision_vocab_size = vision_vocab_size # 8192 + 256 self.tie_vision_embeddings = tie_vision_embeddings self.sample_mode = sample_mode @staticmethod 
def get_partition_rules(scan_layers=False, scan_axis=0): """Partition rules are ordered, so that the beginning rules match first.""" if scan_layers: if scan_axis == 0: return ( # embeddings ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))), ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))), # attention ("attention/(wq|wk|wv)/kernel", PS(None, ("fsdp", "sp"), "tp")), ("attention/wo/kernel", PS(None, "tp", ("fsdp", "sp"))), # mlp ("feed_forward/w1/kernel", PS(None, ("fsdp", "sp"), "tp")), ("feed_forward/w2/kernel", PS(None, "tp", ("fsdp", "sp"))), ("feed_forward/w3/kernel", PS(None, ("fsdp", "sp"), "tp")), # layer norms ("attention_norm/kernel", PS(None, None)), ("ffn_norm/kernel", PS(None, None)), # output head ("transformer/ln_f/kernel", PS(None)), ("lm_head/kernel", PS(("fsdp", "sp"), "tp")), ("vision_head/kernel", PS(("fsdp", "sp"), "tp")), ('.*', PS(None)), ) elif scan_axis == 1: return ( # embeddings ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))), ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))), # attention ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), None, "tp")), ("attention/wo/kernel", PS("tp", None, ("fsdp", "sp"))), # mlp ("feed_forward/w1/kernel", PS(("fsdp", "sp"), None, "tp")), ("feed_forward/w2/kernel", PS("tp", None, ("fsdp", "sp"))), ("feed_forward/w3/kernel", PS(("fsdp", "sp"), None, "tp")), # layer norms ("attention_norm/kernel", PS(None, None)), ("ffn_norm/kernel", PS(None, None)), # output head ("transformer/ln_f/kernel", PS(None)), ("lm_head/kernel", PS(("fsdp", "sp"), "tp")), ("vision_head/kernel", PS(("fsdp", "sp"), "tp")), ('.*', PS(None)), ) else: raise ValueError(f"Invalid scan_axis {scan_axis}") else: return ( # embeddings ("transformer/wte/embedding", PS("tp", ("fsdp", "sp"))), ("transformer/vte/embedding", PS("tp", ("fsdp", "sp"))), # attention ("attention/(wq|wk|wv)/kernel", PS(("fsdp", "sp"), "tp")), ("attention/wo/kernel", PS("tp", ("fsdp", "sp"))), # mlp ("feed_forward/w1/kernel", PS(("fsdp", "sp"), 
"tp")), ("feed_forward/w2/kernel", PS("tp", ("fsdp", "sp"))), ("feed_forward/w3/kernel", PS(("fsdp", "sp"), "tp")), # layer norms ("attention_norm/kernel", PS(None)), ("ffn_norm/kernel", PS(None)), # output head ("transformer/ln_f/kernel", PS(None)), ("lm_head/kernel", PS(("fsdp", "sp"), "tp")), ("vision_head/kernel", PS(("fsdp", "sp"), "tp")), ('.*', PS(None)), ) @classmethod def load_config(cls, path): if path in VIDEO_LLAMA_STANDARD_CONFIGS: return cls.from_dict(VIDEO_LLAMA_STANDARD_CONFIGS[path]) load_type, load_path = path.split('::', 1) if load_type == 'pickle': return cls.from_dict(load_pickle(load_path)['llama_config']) elif load_type == 'json': with open_file(load_path, 'r') as fin: raw_config = fin.read() return cls.from_dict(json.loads(raw_config)) else: raise ValueError(f'Unsupported load config type: {load_type}') class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = VideoLLaMAConfig base_model_prefix = "transformer" module_class: nn.Module = None def __init__( self, config: VideoLLaMAConfig, input_shape: Tuple = (4, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, **kwargs, ): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_cache(self, batch_size, max_length): # init input variables to retrieve cache input_ids = jnp.ones((batch_size, max_length)) attention_mask = jnp.ones_like(input_ids) segment_ids = jnp.zeros_like(input_ids) position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) vision_masks = jnp.ones((batch_size, max_length), dtype=bool) init_variables = self.module.init( jax.random.PRNGKey(0), input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True ) return init_variables["cache"] def init_weights(self, rng, input_shape, params=None): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) vision_masks = jnp.ones(input_ids.shape, dtype=bool) segment_ids = jnp.zeros_like(input_ids) position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params @add_start_docstrings_to_model_forward("") def __call__( self, input_ids, vision_masks, attention_mask=None, 
segment_ids=None, position_ids=None, params: dict = None, past_key_values: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.return_dict batch_size, sequence_length = input_ids.shape if position_ids is None: if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) if segment_ids is None: segment_ids = jnp.zeros((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), jnp.array(vision_masks, dtype="f4"), jnp.array(attention_mask, dtype="i4"), jnp.array(segment_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), not train, False, output_attentions, output_hidden_states, return_dict, rngs=rngs, mutable=mutable, ) # add updated cache to model output if past_key_values is not None and return_dict: outputs, past_key_values = outputs outputs["past_key_values"] = unfreeze(past_key_values["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past_key_values = outputs outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] return outputs class 
FlaxVideoLLaMAModule(nn.Module): config: VideoLLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self): self.embed_dim = self.config.hidden_size self.vte = nn.Embed( self.config.vision_vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, param_dtype=self.param_dtype, ) self.wte = nn.Embed( self.config.vocab_size, self.config.hidden_size, embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, param_dtype=self.param_dtype, ) self.dropout = nn.Dropout(rate=self.config.embd_pdrop) self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype) def __call__( self, input_ids, vision_masks, attention_mask, segment_ids, position_ids, deterministic=True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): input_ids = input_ids.astype("i4") if input_ids.shape[1] == 1: if self.config.sample_mode == 'text': input_embeds = self.wte(input_ids) elif self.config.sample_mode == 'vision': input_embeds = self.vte(input_ids) elif self.config.sample_mode == 'all': raise NotImplementedError else: raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}") else: input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids)) input_vision_embeds = self.vte(jnp.where(vision_masks, input_ids, 0)) vision_masks = vision_masks[..., None].astype("f4") # 1 is vision, 0 is text input_embeds = input_text_embeds * (1 - vision_masks) + input_vision_embeds * vision_masks hidden_states = self.dropout(input_embeds, deterministic=deterministic) outputs = self.h( hidden_states, attention_mask, segment_ids, 
position_ids=position_ids, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = outputs[1] + (hidden_states,) outputs = (hidden_states, all_hidden_states) + outputs[2:] else: outputs = (hidden_states,) + outputs[1:] if not return_dict: return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput( last_hidden_state=hidden_states, hidden_states=outputs[1], attentions=outputs[-1], ) class FlaxVideoLLaMAForCausalLMModule(nn.Module): config: VideoLLaMAConfig dtype: jnp.dtype = jnp.float32 param_dtype: jnp.dtype=jnp.float32 precision: Optional[Union[jax.lax.Precision, str]]=None def setup(self): self.transformer = FlaxVideoLLaMAModule(self.config, dtype=self.dtype) self.vision_head = nn.Dense( self.config.vision_vocab_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), precision=self.precision, ) self.lm_head = nn.Dense( self.config.vocab_size, dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False, kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), precision=self.precision, ) def __call__( self, input_ids, vision_masks, attention_mask=None, segment_ids=None, position_ids=None, deterministic: bool = True, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, ): batch_size, seq_length = input_ids.shape if attention_mask is None: attention_mask = jnp.ones_like(input_ids) if segment_ids is None: segment_ids = jnp.zeros_like(input_ids) if position_ids is None: position_ids = jnp.broadcast_to( jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0), (batch_size, seq_length) ) outputs = self.transformer( input_ids, vision_masks, attention_mask, 
segment_ids, position_ids, deterministic=deterministic, init_cache=init_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] if self.config.tie_vision_embeddings: shared_kernel = self.transformer.variables["params"]["vte"]["embedding"].T vision_logits = self.vision_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) else: vision_logits = self.vision_head(hidden_states) if self.config.tie_word_embeddings: shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) else: lm_logits = self.lm_head(hidden_states) if self.config.sample_mode == 'all': if not return_dict: return (vision_logits, lm_logits,) + outputs[1:] return FlaxCausalLMOutput(logits=(vision_logits, lm_logits), hidden_states=outputs.hidden_states, attentions=outputs.attentions) elif self.config.sample_mode == 'vision': if not return_dict: return (vision_logits,) + outputs[1:] return FlaxCausalLMOutput(logits=vision_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) elif self.config.sample_mode == 'text': if not return_dict: return (lm_logits,) + outputs[1:] return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) else: raise ValueError(f"Invalid sample_mode: {self.config.sample_mode}") @add_start_docstrings("", "") class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel): module_class = FlaxVideoLLaMAForCausalLMModule def prepare_inputs_for_generation( self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, vision_masks = None ): # initializing the cache batch_size, seq_length = input_ids.shape past_key_values = self.init_cache(batch_size, max_length) extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: position_ids = attention_mask.cumsum(axis=-1) - 1 
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) return { "past_key_values": past_key_values, "attention_mask": extended_attention_mask, "position_ids": position_ids, "vision_masks": vision_masks } def update_inputs_for_generation(self, model_outputs, model_kwargs): return { "past_key_values": model_outputs.past_key_values, "position_ids": model_kwargs["position_ids"][:, -1:] + 1, "attention_mask": model_kwargs["attention_mask"], "vision_masks": model_kwargs["vision_masks"] } def _sample_vision( self, input_ids: None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, prng_key: Optional[jnp.ndarray] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None, logits_warper: Optional[FlaxLogitsProcessorList] = None, cfg_scales: jnp.ndarray = 1.0, trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): # init values max_length = max_length if max_length is not None else self.generation_config.max_length pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) batch_size, cur_len = input_ids.shape initial_len = cur_len eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None) pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32) cur_len = jnp.array(cur_len) # per batch-item holding current token in loop. sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32) sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0)) # per batch-item state bit indicating if sentence has finished. 
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_) # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop # and pass it the `encoder_outputs`, which are part of the `model_kwargs`. model = self.decode if self.config.is_encoder_decoder else self # initialize model specific kwargs model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs) # initialize state state = SampleState( cur_len=cur_len, sequences=sequences, running_token=input_ids, is_sent_finished=is_sent_finished, prng_key=prng_key, model_kwargs=model_kwargs, ) def sample_search_cond_fn(state): """state termination condition fn.""" has_reached_max_length = state.cur_len == max_length all_sequence_finished = jnp.all(state.is_sent_finished) finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished) return ~finish_generation def sample_search_body_fn(state): """state update fn.""" prng_key, prng_key_next = jax.random.split(state.prng_key) model_outputs = model(state.running_token, params=params, **state.model_kwargs) logits = model_outputs.logits[:, -1] cond_logits, uncond_logits = jnp.split(logits, 2, axis=0) logits = uncond_logits + cfg_scales[:, None] * (cond_logits - uncond_logits) # apply min_length, ... 
logits = logits_processor(state.sequences, logits, state.cur_len) # apply top_p, top_k, temperature logits = logits_warper(logits, logits, state.cur_len) next_token = jax.random.categorical(prng_key, logits, axis=-1) next_token = jax.lax.cond( (state.cur_len - initial_len + 1) % 257 == 0, lambda: jnp.full_like(next_token, 8192), lambda: next_token ) next_token = jnp.concatenate([next_token, next_token], axis=0) #next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) next_token = next_token[:, None] next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) return SampleState( cur_len=state.cur_len + 1, sequences=next_sequences, running_token=next_token, is_sent_finished=next_is_sent_finished, model_kwargs=next_model_kwargs, prng_key=prng_key_next, ) # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU if input_ids.shape[1] > 1: state = sample_search_body_fn(state) if not trace: state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state) else: state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state) return FlaxSampleOutput(sequences=state.sequences) def generate_vision( self, input_ids: jnp.ndarray, cfg_scales: jnp.ndarray, generation_config: Optional[GenerationConfig] = None, prng_key: Optional[jnp.ndarray] = None, trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None, **kwargs, ): # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: # legacy: users may modify the 
model configuration to control generation. To trigger this legacy behavior, # two conditions must be met # 1) the generation config must have been created from the model config (`_from_model_config` field); # 2) the generation config must have seen no modification since its creation (the hash is the same). if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( self.generation_config ): new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." " Please use and modify the model generation configuration (see" " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList() # set init values prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask") is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

        # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
        if not self.config.is_encoder_decoder and not trace:
            if (
                generation_config.pad_token_id is not None
                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        batch_size = input_ids.shape[0]

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
            # prepare decoder_input_ids for generation
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                model_kwargs=model_kwargs,
            )

        # Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                f" the maximum length ({generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            logits_processor=logits_processor,
        )

        if not generation_config.do_sample and generation_config.num_beams == 1:
            raise NotImplementedError
        elif generation_config.do_sample and generation_config.num_beams == 1:
            logits_warper = self._get_logits_warper(generation_config=generation_config)
            return self._sample_vision(
                input_ids,
                generation_config.max_length,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                cfg_scales=cfg_scales,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not generation_config.do_sample and generation_config.num_beams > 1:
            raise NotImplementedError
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")

================================================
FILE: lwm/vqgan.py
================================================
from typing import Optional
from functools import cached_property, partial
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import jax_utils
from transformers.configuration_utils import PretrainedConfig
from ml_collections import ConfigDict
from tux import function_args_to_config, open_file


class VQGAN:
    def __init__(self, vqgan_checkpoint, replicate=False):
        assert vqgan_checkpoint != ''
        self.replicate = replicate
        self.config = VQGANConfig.get_default_config()
        self.params = pickle.load(open_file(vqgan_checkpoint, 'rb'))
        if replicate:
            self.params = jax_utils.replicate(self.params)
        else:
            self.params = jax.jit(lambda x: x)(self.params)
        self.model = VQGANModel(self.config)

    def _wrap_fn(self, fn):
        if self.replicate:
            return jax.pmap(fn, devices=jax.local_devices())
        else:
            return jax.jit(fn)

    @cached_property
    def _encode(self):
        def fn(pixel_values, params):
            return self.model.apply(
                {'params': params},
                pixel_values,
                method=self.model.encode
            )
return partial(self._wrap_fn(fn), params=self.params) @cached_property def _decode(self): def fn(encoding, params): return self.model.apply( {'params': params}, encoding, method=self.model.decode ) return partial(self._wrap_fn(fn), params=self.params) def encode(self, pixel_values): return self._encode(pixel_values) def decode(self, encoding): return self._decode(encoding) class VQGANConfig(PretrainedConfig): model_type = "vqgan" def __init__( self, resolution=256, num_channels=3, hidden_channels=128, channel_mult=(1, 2, 2, 4, 6), num_res_blocks=2, attn_resolutions=(), no_attn_mid_block=True, z_channels=64, num_embeddings=8192, quantized_embed_dim=64, dropout=0.0, resample_with_conv=True, commitment_cost=0.25 ): self.resolution = resolution self.num_channels = num_channels self.hidden_channels = hidden_channels self.channel_mult = channel_mult self.num_res_blocks = num_res_blocks self.attn_resolutions = attn_resolutions self.no_attn_mid_block = no_attn_mid_block self.z_channels = z_channels self.num_embeddings = num_embeddings self.quantized_embed_dim = quantized_embed_dim self.dropout = dropout self.resample_with_conv = resample_with_conv self.commitment_cost = commitment_cost @classmethod def get_default_config(cls, updates=None): config = function_args_to_config(cls.__init__) if updates is not None: config.update(ConfigDict(updates).copy_and_resolve_references()) config.num_resolutions = len(config.channel_mult) return config @classmethod def load_config(cls, path): return cls.get_default_config(cls) class VQGANModel(nn.Module): config: VQGANConfig def setup(self): self.encoder = Encoder(self.config) self.decoder = Decoder(self.config) self.quantize = VectorQuantizer( self.config.num_embeddings, self.config.quantized_embed_dim ) self.quant_conv = nn.Conv(self.config.quantized_embed_dim, [1, 1]) self.post_quant_conv = nn.Conv(self.config.z_channels, [1, 1]) def encode(self, pixel_values): T = None if len(pixel_values.shape) == 5: # video T = pixel_values.shape[1] 
pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) hidden_states = self.encoder(pixel_values) hidden_states = self.quant_conv(hidden_states) quantized_states, codebook_indices = self.quantize(hidden_states) if T is not None: quantized_states = quantized_states.reshape(-1, T, *quantized_states.shape[1:]) codebook_indices = codebook_indices.reshape(-1, T, *codebook_indices.shape[1:]) return quantized_states, codebook_indices def decode(self, encoding, is_codebook_indices=True): if is_codebook_indices: encoding = self.quantize(None, encoding) T = None if len(encoding.shape) == 5: T = encoding.shape[1] encoding = encoding.reshape(-1, *encoding.shape[2:]) hidden_states = self.post_quant_conv(encoding) reconstructed_pixel_values = self.decoder(hidden_states) if T is not None: reconstructed_pixel_values = reconstructed_pixel_values.reshape(-1, T, *reconstructed_pixel_values.shape[1:]) return jnp.clip(reconstructed_pixel_values, -1, 1) def __call__(self, pixel_values): encoding = self.encode(pixel_values)[1] recon = self.decode(encoding) return recon class Encoder(nn.Module): config: VQGANConfig @nn.compact def __call__(self, pixel_values): assert pixel_values.shape[1] == pixel_values.shape[2] == self.config.resolution, pixel_values.shape hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values) for i_level in range(self.config.num_resolutions): hidden_states = DownsamplingBlock(self.config, i_level)(hidden_states) hidden_states = MidBlock( self.config, self.config.no_attn_mid_block, self.config.dropout )(hidden_states) hidden_states = nn.GroupNorm()(hidden_states) hidden_states = nn.silu(hidden_states) hidden_states = nn.Conv(self.config.z_channels, [3, 3])(hidden_states) return hidden_states class Decoder(nn.Module): config: VQGANConfig @nn.compact def __call__(self, hidden_states): hidden_states = nn.Conv( self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1], [3, 3] )(hidden_states) hidden_states = 
MidBlock( self.config, self.config.no_attn_mid_block, self.config.dropout )(hidden_states) for i_level in reversed(range(self.config.num_resolutions)): hidden_states = UpsamplingBlock(self.config, i_level)(hidden_states) hidden_states = nn.GroupNorm()(hidden_states) hidden_states = nn.silu(hidden_states) hidden_states = nn.Conv(self.config.num_channels, [3, 3])(hidden_states) return hidden_states class VectorQuantizer(nn.Module): n_e: int e_dim: int @nn.compact def __call__(self, z, encoding_indices=None): def quantize(encoding_indices): w = jax.device_put(embeddings) return w[(encoding_indices,)] embeddings = self.param( 'embeddings', lambda rng, shape, dtype: jax.random.uniform( rng, shape, dtype, minval=-1.0 / self.n_e, maxval=1.0 / self.n_e ), [self.n_e, self.e_dim], jnp.float32 ) if encoding_indices is not None: return quantize(encoding_indices) z_flattened = z.reshape(-1, z.shape[-1]) d = jnp.sum(z_flattened ** 2, axis=1, keepdims=True) + \ jnp.sum(embeddings.T ** 2, axis=0, keepdims=True) - \ 2 * jnp.einsum('bd,nd->bn', z_flattened, embeddings) min_encoding_indices = jnp.argmin(d, axis=1) z_q = quantize(min_encoding_indices) z_q = jnp.reshape(z_q, z.shape) z_q = z + jax.lax.stop_gradient(z_q - z) encodings_one_hot = jax.nn.one_hot(min_encoding_indices, num_classes=self.n_e) assert len(encodings_one_hot.shape) == 2 min_encoding_indices = jnp.reshape(min_encoding_indices, z.shape[:-1]) return z_q, min_encoding_indices class DownsamplingBlock(nn.Module): config: VQGANConfig block_idx: int @nn.compact def __call__(self, hidden_states): block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] for _ in range(self.config.num_res_blocks): hidden_states = ResnetBlock( block_out, dropout_prob=self.config.dropout )(hidden_states) if hidden_states.shape[1] in self.config.attn_resolutions: hidden_states = AttnBlock()(hidden_states) if self.block_idx != self.config.num_resolutions - 1: hidden_states = 
Downsample(self.config.resample_with_conv)(hidden_states) return hidden_states class ResnetBlock(nn.Module): out_channels: Optional[int] = None use_conv_shortcut: bool = False dropout_prob: float = 0.0 @nn.compact def __call__(self, hidden_states): out_channels = self.out_channels or hidden_states.shape[-1] residual = hidden_states hidden_states = nn.GroupNorm()(hidden_states) hidden_states = nn.silu(hidden_states) hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states) hidden_states = nn.GroupNorm()(hidden_states) hidden_states = nn.silu(hidden_states) hidden_states = nn.Dropout(self.dropout_prob, deterministic=True)(hidden_states) hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states) if out_channels != residual.shape[-1]: if self.use_conv_shortcut: residual = nn.Conv(out_channels, [3, 3])(residual) else: residual = nn.Conv(out_channels, [1, 1])(residual) return hidden_states + residual class AttnBlock(nn.Module): @nn.compact def __call__(self, hidden_states): residual = hidden_states hidden_states = nn.GroupNorm()(hidden_states) query = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states) key = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states) value = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states) query, key, value = map( lambda x: x.reshape(x.shape[0], -1, x.shape[-1]), [query, key, value] ) attn_weights = jnp.einsum("bqd,bkd->bqk", query, key) attn_weights *= hidden_states.shape[-1] ** -0.5 attn_weights = jax.nn.softmax(attn_weights, axis=-1) hidden_states = jnp.einsum("bqk,bkd->bqd", attn_weights, value) hidden_states = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states) return hidden_states + residual class Downsample(nn.Module): with_conv: bool @nn.compact def __call__(self, hidden_states): if self.with_conv: hidden_states = jnp.pad( hidden_states, [(0, 0), (0, 1), (0, 1), (0, 0)] ) hidden_states = nn.Conv( hidden_states.shape[-1], [3, 3], strides=[2, 2], padding="VALID" )(hidden_states) else: hidden_states = 
nn.avg_pool(hidden_states, [2, 2], [2, 2]) return hidden_states class Upsample(nn.Module): with_conv: bool @nn.compact def __call__(self, hidden_states): B, H, W, C = hidden_states.shape hidden_states = jax.image.resize( hidden_states, (B, H * 2, W * 2, C), method="nearest" ) if self.with_conv: hidden_states = nn.Conv(hidden_states.shape[-1], [3, 3])(hidden_states) return hidden_states class UpsamplingBlock(nn.Module): config: VQGANConfig block_idx: int @nn.compact def __call__(self, hidden_states): block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx] for _ in range(self.config.num_res_blocks + 1): hidden_states = ResnetBlock( block_out, dropout_prob=self.config.dropout )(hidden_states) if hidden_states.shape[1] in self.config.attn_resolutions: hidden_states = AttnBlock()(hidden_states) if self.block_idx != 0: hidden_states = Upsample(self.config.resample_with_conv)(hidden_states) return hidden_states class MidBlock(nn.Module): config: VQGANConfig no_attn: bool dropout: float @nn.compact def __call__(self, hidden_states): hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states) if not self.no_attn: hidden_states = AttnBlock()(hidden_states) hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states) return hidden_states ================================================ FILE: scripts/create_needle_data.py ================================================ import os import argparse import json from tqdm import tqdm from datasets import load_dataset parser = argparse.ArgumentParser() parser.add_argument("--output_path", type=str, default="data/pg19.jsonl") args = parser.parse_args() os.makedirs(os.path.dirname(args.output_path), exist_ok=True) dset = load_dataset("pg19")["train"] with open(args.output_path, "w") as f: for elem in tqdm(dset): data = {"text": elem["text"]} f.write(f"{json.dumps(data)}\n") ================================================ FILE: scripts/eval_needle.py 
================================================ from absl.app import run import time import json import math import os from tqdm import tqdm import random from functools import cached_property import numpy as np import jax from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as PS import gcsfs import tiktoken from transformers import GenerationConfig, AutoTokenizer from tux import ( define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig, set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng, match_partition_rules, make_shard_and_gather_fns, with_sharding_constraint, tree_apply, open_file ) from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM FLAGS, FLAGS_DEF = define_flags_with_default( haystack_file="", max_tokens_per_batch=2000000, output_file="results.json", context_lengths_min=1000, context_lengths_max=32000, n_context_length_intervals=3, n_document_depth_intervals=3, n_rounds=2, seed=1234, mesh_dim='1,-1,1,1', dtype='fp32', load_llama_config='', update_llama_config='', load_checkpoint='', tokenizer='LargeWorldModel/LWM-Text-1M', checkpointer=StreamingCheckpointer.get_default_config(), llama=LLaMAConfig.get_default_config(), jax_distributed=JaxDistributedConfig.get_default_config(), ) class LLMNeedleHaystackTester: OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document or repeat your findings. Keep your response short and direct. 
ASSISTANT: " RANDOM_NEEDLE_CITIES = [ 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City', 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar', 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman', 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco', 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali', 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki', 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo', 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis', 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta', 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels', 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul', 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta' ] def __init__(self, needle="", haystack_file="", retrieval_question="What is the special magic {} number?", results_version = 1, rnd_number_digits = 7, context_lengths_min = 1000, context_lengths_max = 126000, context_lengths_num_intervals = 10, document_depth_percent_min = 0, document_depth_percent_max = 100, document_depth_percent_intervals = 10, document_depth_percent_interval_type = "linear", save_results = False, final_context_length_buffer = 200, print_ongoing_status = True): needle="\nThe special magic {city} number is: {rnd_number}\n" self.needle = needle if not needle or not haystack_file or not retrieval_question: raise ValueError("Needle, haystack, and retrieval_question must be provided.") self.rnd_number_digits = rnd_number_digits self.context_lengths_num_intervals = 
context_lengths_num_intervals self.document_depth_percent_intervals = document_depth_percent_intervals self.haystack_file = haystack_file self.retrieval_question = retrieval_question self.results_version = results_version self.save_results = save_results self.final_context_length_buffer = final_context_length_buffer self.print_ongoing_status = print_ongoing_status self.testing_results = [] self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int) if document_depth_percent_interval_type == 'linear': self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int) elif document_depth_percent_interval_type == 'sigmoid': self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)] else: raise ValueError(f"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}") self.model = Sampler() self.enc = AutoTokenizer.from_pretrained(FLAGS.tokenizer) self.enc_tiktoken = tiktoken.encoding_for_model("gpt-4-1106-preview") def generate_random_number(self, num_digits): lower_bound = 10**(num_digits - 1) upper_bound = 10**num_digits - 1 return random.randint(lower_bound, upper_bound) def logistic(self, x, L=100, x0=50, k=.1): if x == 0: return 0 if x == 100: return 100 return np.round(L / (1 + np.exp(-k * (x - x0))), 3) def read_context_files(self, n): max_context_length = max(self.context_lengths) contexts = [] f = open_file(self.haystack_file, 'r') for _ in range(n): context = "" toks = 0 while toks < max_context_length: text = json.loads(f.readline())['text'] context += text toks += len(self.enc.encode(text)) contexts.append(context) return contexts def encode_and_trim(self, context, context_length): tokens = self.enc.encode(context) if len(tokens) > 
context_length:
            context = self.enc.decode(tokens[:context_length])
        return context

    def create_contexts(self, needle_rnd_number, insert_needle, random_city, trim_context, context_length, depth_percent, seed):
        if self.save_results:
            if self.result_exists(context_length, depth_percent):
                return
        needle = self.needle.format(city=random_city, rnd_number=needle_rnd_number)
        question = self.retrieval_question.format(random_city)
        if not insert_needle:
            needle = " "  # replace needle with a space
        context = self.generate_context(needle, trim_context, context_length, depth_percent)
        results = {
            'context' : context,
            'context_length' : int(context_length),
            'depth_percent' : float(depth_percent),
            'needle' : needle,
            'question' : question,
            'insert_needle' : insert_needle,
            'needle_rnd_number' : needle_rnd_number,
            'seed': seed,
        }
        return results

    def insert_needle(self, needle, context, depth_percent, context_length):
        tokens_needle = self.enc_tiktoken.encode(needle)
        tokens_context = self.enc_tiktoken.encode(context)

        # Reduce the context length by `final_context_length_buffer` tokens (200 by default).
        # This leaves room for the system message, the user question, and the response.
        context_length -= self.final_context_length_buffer

        # If the context plus the needle is longer than the context length (which it will be),
        # trim the context by the needle length
        if len(tokens_context) + len(tokens_needle) > context_length:
            tokens_context = tokens_context[:context_length - len(tokens_needle)]

        if depth_percent == 100:
            # A depth of 100 percent means the needle is the last thing in the document, so append it at the end
            tokens_new_context = tokens_context + tokens_needle
        else:
            # Get the position (in tokens) at which to insert the needle
            insertion_point = int(len(tokens_context) * (depth_percent / 100))

            # tokens_new_context represents the tokens before the needle
            tokens_new_context = tokens_context[:insertion_point]

            # We want to place the needle at a sentence break, so we first look up the token(s) for '.'
            period_tokens = self.enc_tiktoken.encode('.')

            # Then we iterate backwards until we find the first period
            while tokens_new_context and tokens_new_context[-1] not in period_tokens:
                insertion_point -= 1
                tokens_new_context = tokens_context[:insertion_point]

            # Once we get there, add the needle and append the rest of the context after it.
            # Now we have a needle in a haystack
            tokens_new_context += tokens_needle + tokens_context[insertion_point:]

        # Convert back to a string and return it
        new_context = self.enc_tiktoken.decode(tokens_new_context)
        return new_context

    def generate_context(self, needle, trim_context, context_length, depth_percent):
        context = self.insert_needle(needle, trim_context, depth_percent, context_length)
        return context

    def compute_max_input_length(self, context_length, buffer=1024):
        block_size = self.model.block_size
        context_length += buffer
        context_length = math.ceil(context_length / block_size) * block_size
        return int(context_length)

    def run_test(self):
        fs = gcsfs.GCSFileSystem()
        contexts = []
        template = self.OURS_TEMPLATE

        def _key_from_result(result):
            return (result['context_length'], result['depth_percent'], result['seed'])

        results = []
        completed = set()

        def exists(fname):
            if fname.startswith('gs://'):
                return fs.exists(fname)
            else:
                return os.path.exists(fname)

        if exists(FLAGS.output_file):
            with open_file(FLAGS.output_file, 'r') as f:
                results = json.load(f)
                completed = set([_key_from_result(result) for result in results])
        print('completed', len(completed))

        full_contexts = self.read_context_files(FLAGS.n_rounds)
        full_tokens = [self.enc.encode(full_context) for full_context in tqdm(full_contexts)]

        start = time.time()
        for context_length in self.context_lengths:
            trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in tqdm(full_tokens)]
            max_input_length = self.compute_max_input_length(context_length)
            contexts = []
            for depth_percent in self.document_depth_percents:
                for i in range(FLAGS.n_rounds):
                    if
(int(context_length), float(depth_percent), i) in completed: continue random_city = random.choice(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES) insert_needle = True needle_rnd_number = str(self.generate_random_number(self.rnd_number_digits)) print("context length: " + str(context_length)) print("depth_percent : " + str(depth_percent)) context = self.create_contexts(needle_rnd_number, insert_needle, random_city, trim_contexts[i], context_length, depth_percent, i) contexts.append(context) if len(contexts) == 0: continue B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size) B = int(B / self.model.data_dim) * self.model.data_dim if B < self.model.data_dim: B = self.model.data_dim elif B > len(contexts): B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim) if len(contexts) % B == 0: n_pad = 0 else: n_pad = B - len(contexts) % B for _ in range(n_pad): contexts.insert(0, contexts[0]) pbar = tqdm(total=len(contexts)) for i in range(0, len(contexts), B): contexts_i = contexts[i:i + B] prompts = [ template.format(context=context['context'], question=context['question']) for context in contexts_i ] outs = self.model(prompts, max_input_length) for j, (context, out) in enumerate(zip(contexts_i, outs)): if i + j < n_pad: continue results.append({ 'context_length': context['context_length'], 'depth_percent': context['depth_percent'], 'response': out, 'answer': context['needle_rnd_number'], 'correct': context['needle_rnd_number'] in out, 'seed': context['seed'], }) print(results[-1]) if jax.process_index() == 0: with open_file(FLAGS.output_file, 'w') as f: json.dump(results, f) pbar.update(len(contexts_i)) pbar.close() print('elapsed', time.time() - start) print('done') def print_start_test_summary(self): print ("\n") print ("Starting Needle In A Haystack Testing...") print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}") print (f"- Document Depths: 
{len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%") print (f"- Needle: {self.needle.strip()}") print ("\n\n") def start_test(self): if self.print_ongoing_status: self.print_start_test_summary() self.run_test() class Sampler: def __init__(self): self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left') self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) self.sharded_rng = next_rng() self._load_model() @property def block_size(self): # return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp'] @property def data_dim(self): return self.mesh.shape['dp'] * self.mesh.shape['fsdp'] def _load_model(self): if FLAGS.load_llama_config != '': llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config) updates = LLaMAConfig(**FLAGS.llama) llama_config.update(dict( scan_attention=updates.scan_attention, scan_mlp=updates.scan_mlp, scan_query_chunk_size=updates.scan_query_chunk_size, scan_key_chunk_size=updates.scan_key_chunk_size, scan_mlp_chunk_size=updates.scan_mlp_chunk_size, scan_layers=updates.scan_layers, param_scan_axis=updates.param_scan_axis, )) else: llama_config = LLaMAConfig(**FLAGS.llama) if FLAGS.update_llama_config != '': llama_config.update(dict(eval(FLAGS.update_llama_config))) llama_config.update(dict( bos_token_id=self.tokenizer.bos_token_id, eos_token_id=self.tokenizer.eos_token_id, )) llama_config.update(dict(mesh_dim=FLAGS.mesh_dim)) self.config = llama_config with jax.default_device(jax.devices("cpu")[0]): _, self.params = StreamingCheckpointer.load_trainstate_checkpoint( FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30 ) self.model = FlaxLLaMAForCausalLM( llama_config, input_shape=(512, self.block_size), 
seed=FLAGS.seed, _do_init=False, dtype=get_float_dtype_by_name(FLAGS.dtype), ) self.model_ps = match_partition_rules( LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params ) shard_fns, _ = make_shard_and_gather_fns( self.model_ps, get_float_dtype_by_name(FLAGS.dtype) ) with self.mesh: self.params = tree_apply(shard_fns, self.params) @cached_property def _forward_generate(self): def fn(params, rng, batch): batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp')) rng_generator = JaxRNG(rng) output = self.model.generate( batch['input_ids'], attention_mask=batch['attention_mask'], params=params['params'], prng_key=rng_generator(), generation_config=GenerationConfig( max_new_tokens=self.block_size, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, temperature=0., do_sample=False, num_beams=1, top_k=50, top_p=1.0, ) ).sequences[:, batch['input_ids'].shape[1]:] return output, rng_generator() return pjit( fn, in_shardings=(self.model_ps, PS(), PS()), out_shardings=(PS(), PS()) ) def __call__(self, prompts, max_input_length): inputs = self.prefix_tokenizer( prompts, padding='max_length', truncation=True, max_length=max_input_length, return_tensors='np' ) batch = dict( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask ) with self.mesh: output, self.sharded_rng = self._forward_generate( self.params, self.sharded_rng, batch ) output = jax.device_get(output) output_text = [] for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)): if self.tokenizer.eos_token in text: text = text.split(self.tokenizer.eos_token, maxsplit=1)[0] output_text.append(text) return output_text def main(argv): JaxDistributedConfig.initialize(FLAGS.jax_distributed) set_random_seed(FLAGS.seed) ht = LLMNeedleHaystackTester( haystack_file=FLAGS.haystack_file, context_lengths_min=FLAGS.context_lengths_min, context_lengths_max=FLAGS.context_lengths_max, 
context_lengths_num_intervals=FLAGS.n_context_length_intervals, document_depth_percent_intervals=FLAGS.n_document_depth_intervals, ) ht.start_test() if __name__ == "__main__": run(main) ================================================ FILE: scripts/eval_needle_multi.py ================================================ from absl.app import run import glob import time import json import math import os from tqdm import tqdm import random from functools import cached_property import numpy as np import jax from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as PS import gcsfs import tiktoken from transformers import GenerationConfig, AutoTokenizer from tux import ( define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig, set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng, match_partition_rules, make_shard_and_gather_fns, with_sharding_constraint, tree_apply, open_file ) from lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM FLAGS, FLAGS_DEF = define_flags_with_default( haystack_file="", max_tokens_per_batch=2000000, output_file="results.json", context_lengths_min=1000, context_lengths_max=32000, n_context_length_intervals=3, n_document_depth_intervals=3, n_rounds=2, n_needles_total=4, n_needles_retrieve=4, seed=1234, mesh_dim='1,-1,1,1', dtype='fp32', load_llama_config='', update_llama_config='', load_checkpoint='', tokenizer='LargeWorldModel/LWM-Text-1M', checkpointer=StreamingCheckpointer.get_default_config(), llama=LLaMAConfig.get_default_config(), jax_distributed=JaxDistributedConfig.get_default_config(), ) class LLMNeedleHaystackTester: OURS_TEMPLATE = "You are a helpful assistant. USER: {context} {question} Don't give information outside the document. 
ASSISTANT: " RANDOM_NEEDLE_CITIES = [ 'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City', 'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar', 'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman', 'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco', 'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali', 'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki', 'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo', 'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis', 'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta', 'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels', 'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul', 'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta' ] def __init__(self, needle="", haystack_file="", retrieval_question="What are the special magic numbers for {}?", results_version = 1, rnd_number_digits = 7, context_lengths_min = 1000, context_lengths_max = 126000, context_lengths_num_intervals = 10, document_depth_percent_min = 0, document_depth_percent_max = 100, document_depth_percent_intervals = 10, document_depth_percent_interval_type = "linear", save_results = False, final_context_length_buffer = 200, print_ongoing_status = True): needle="\nThe special magic {city} number is: {rnd_number}\n" self.needle = needle if not needle or not haystack_file or not retrieval_question: raise ValueError("Needle, haystack, and retrieval_question must be provided.") self.rnd_number_digits = rnd_number_digits self.context_lengths_num_intervals = 
context_lengths_num_intervals
        self.document_depth_percent_intervals = document_depth_percent_intervals
        self.haystack_file = haystack_file
        self.retrieval_question = retrieval_question
        self.results_version = results_version
        self.save_results = save_results
        self.final_context_length_buffer = final_context_length_buffer
        self.print_ongoing_status = print_ongoing_status
        self.testing_results = []

        self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
        self.context_lengths = self.context_lengths.tolist()

        if document_depth_percent_interval_type == 'linear':
            self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
        elif document_depth_percent_interval_type == 'sigmoid':
            # Wrap in np.array so the .tolist() call below works for this branch too.
            self.document_depth_percents = np.array([self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)])
        else:
            raise ValueError(f"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}")
        self.document_depth_percents = self.document_depth_percents.tolist()

        self.model = Sampler()

        self.enc = AutoTokenizer.from_pretrained(FLAGS.tokenizer)
        self.enc_tiktoken = tiktoken.encoding_for_model("gpt-4-1106-preview")

    def generate_random_number(self, num_digits):
        lower_bound = 10**(num_digits - 1)
        upper_bound = 10**num_digits - 1
        return random.randint(lower_bound, upper_bound)

    def logistic(self, x, L=100, x0=50, k=.1):
        if x == 0:
            return 0
        if x == 100:
            return 100
        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)

    def read_context_files(self, n):
        max_context_length = max(self.context_lengths)
        contexts = []
        f = open_file(self.haystack_file, 'r')
        for i in range(n):
            context = ""
            while len(self.enc.encode(context)) < max_context_length:
                context += json.loads(f.readline())['text']
            contexts.append(context)
        return contexts

    def
encode_and_trim(self, context, context_length): tokens = self.enc.encode(context) if len(tokens) > context_length: context = self.enc.decode(tokens[:context_length]) return context def create_contexts(self, needles_info, random_cities_retrieve, context, context_length, seed): assert all([random_city in needles_info for random_city in random_cities_retrieve]) for random_city, (needle_rnd_number, depth_percent) in needles_info.items(): context = self.generate_context( self.needle.format(city=random_city, rnd_number=needle_rnd_number), context, context_length, depth_percent ) if len(random_cities_retrieve) == 1: question = f"What is the special magic number for {random_cities_retrieve[0]}?" else: q = ', '.join(random_cities_retrieve[:-1]) + ', and ' + random_cities_retrieve[-1] question = self.retrieval_question.format(q) results = { 'context' : context, 'context_length' : int(context_length), 'needles_info': needles_info, 'question' : question, 'cities_to_retrieve' : random_cities_retrieve, 'seed': seed, } return results def insert_needle(self, needle, context, depth_percent, context_length): tokens_needle = self.enc_tiktoken.encode(needle) tokens_context = self.enc_tiktoken.encode(context) # Reducing the context length by 150 buffer. This is to account for system message, the user question, and response. 
        context_length -= self.final_context_length_buffer

        # If your context + needle are longer than the context length (which it will be), then reduce tokens from the context by the needle length
        if len(tokens_context) + len(tokens_needle) > context_length:
            tokens_context = tokens_context[:context_length - len(tokens_needle)]

        if depth_percent == 100:
            # If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end
            tokens_new_context = tokens_context + tokens_needle
        else:
            # Go get the position (in terms of tokens) to insert your needle
            insertion_point = int(len(tokens_context) * (depth_percent / 100))

            # tokens_new_context represents the tokens before the needle
            tokens_new_context = tokens_context[:insertion_point]

            # We want to make sure that we place our needle at a sentence break so we first see what token a '.' is
            period_tokens = self.enc_tiktoken.encode('.')

            # Then we iterate backwards until we find the first period
            while tokens_new_context and tokens_new_context[-1] not in period_tokens:
                insertion_point -= 1
                tokens_new_context = tokens_context[:insertion_point]

            # Once we get there, then add in your needle, and stick the rest of your context in on the other end.
            # Now we have a needle in a haystack
            tokens_new_context += tokens_needle + tokens_context[insertion_point:]

        # Convert back to a string and return it
        new_context = self.enc_tiktoken.decode(tokens_new_context)
        return new_context

    def generate_context(self, needle, trim_context, context_length, depth_percent):
        context = self.insert_needle(needle, trim_context, depth_percent, context_length)
        return context

    def compute_max_input_length(self, context_length, buffer=1024):
        block_size = self.model.block_size
        context_length += buffer
        # context_length = 2 ** math.ceil(math.log2(context_length))
        context_length = math.ceil(context_length / block_size) * block_size
        return int(context_length)

    def run_test(self):
        fs = gcsfs.GCSFileSystem()
        contexts = []
        template = self.OURS_TEMPLATE

        def _key_from_result(result):
            # Results written by this script have no 'depth_percent' field;
            # key on (context_length, seed) so it matches the `completed` check below.
            return (result['context_length'], result['seed'])

        results = []
        completed = set()
        def exists(fname):
            if fname.startswith('gs://'):
                return fs.exists(fname)
            else:
                return os.path.exists(fname)
        if exists(FLAGS.output_file):
            with open_file(FLAGS.output_file, 'r') as f:
                results = json.load(f)
                completed = set([_key_from_result(result) for result in results])
            print('completed', len(completed))

        full_contexts = self.read_context_files(FLAGS.n_rounds)
        full_tokens = [self.enc.encode(full_context) for full_context in full_contexts]

        start = time.time()
        for context_length in self.context_lengths:
            trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in full_tokens]
            max_input_length = self.compute_max_input_length(context_length)
            contexts = []
            for i in range(FLAGS.n_rounds):
                if (int(context_length), i) in completed:
                    continue
                random_cities = random.sample(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES, FLAGS.n_needles_total)
                document_depths = random.sample(self.document_depth_percents, FLAGS.n_needles_total)
                random_cities_retrieve = random.sample(random_cities, FLAGS.n_needles_retrieve)
                needles_info = {}
                for random_city, depth_percent in
zip(random_cities, document_depths):
                    needles_info[random_city] = (
                        str(self.generate_random_number(self.rnd_number_digits)),
                        depth_percent
                    )
                context = self.create_contexts(needles_info, random_cities_retrieve, trim_contexts[i], context_length, i)
                contexts.append(context)

            if len(contexts) == 0:
                continue

            B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
            B = int(B / self.model.data_dim) * self.model.data_dim
            if B < self.model.data_dim:
                B = self.model.data_dim
            elif B > len(contexts):
                B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
            # Pad only when the last batch is short; an exact multiple of B needs no padding.
            if len(contexts) % B == 0:
                n_pad = 0
            else:
                n_pad = B - len(contexts) % B
            for _ in range(n_pad):
                contexts.insert(0, contexts[0])

            pbar = tqdm(total=len(contexts))
            for i in range(0, len(contexts), B):
                contexts_i = contexts[i:i + B]
                prompts = [
                    template.format(context=context['context'], question=context['question'])
                    for context in contexts_i
                ]
                outs = self.model(prompts, max_input_length)
                for j, (context, out) in enumerate(zip(contexts_i, outs)):
                    if i + j < n_pad:
                        continue
                    rnd_nums_to_retrieve = [
                        context['needles_info'][city][0] for city in context['cities_to_retrieve']
                    ]
                    results.append({
                        'context_length': context['context_length'],
                        'needles_info': context['needles_info'],
                        'question': context['question'],
                        'answer': rnd_nums_to_retrieve,
                        'response': out,
                        'correct': [rnd_num in out for rnd_num in rnd_nums_to_retrieve],
                        'seed': context['seed'],
                    })
                    print(results[-1]['correct'], out, rnd_nums_to_retrieve)
                if jax.process_index() == 0:
                    with open_file(FLAGS.output_file, 'w') as f:
                        json.dump(results, f)
                pbar.update(len(contexts_i))
            pbar.close()
        print('elapsed', time.time() - start)
        print('done')

    def print_start_test_summary(self):
        print ("\n")
        print ("Starting Needle In A Haystack Testing...")
        print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
        print (f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: 
{max(self.document_depth_percents)}%") print (f"- Needle: {self.needle.strip()}") print ("\n\n") def start_test(self): if self.print_ongoing_status: self.print_start_test_summary() self.run_test() class Sampler: def __init__(self): self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim) self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left') self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) self.sharded_rng = next_rng() self._load_model() @property def block_size(self): # return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp'] @property def data_dim(self): return self.mesh.shape['dp'] * self.mesh.shape['fsdp'] def _load_model(self): if FLAGS.load_llama_config != '': llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config) updates = LLaMAConfig(**FLAGS.llama) llama_config.update(dict( scan_attention=updates.scan_attention, scan_mlp=updates.scan_mlp, scan_query_chunk_size=updates.scan_query_chunk_size, scan_key_chunk_size=updates.scan_key_chunk_size, scan_mlp_chunk_size=updates.scan_mlp_chunk_size, scan_layers=updates.scan_layers, param_scan_axis=updates.param_scan_axis, )) else: llama_config = LLaMAConfig(**FLAGS.llama) if FLAGS.update_llama_config != '': llama_config.update(dict(eval(FLAGS.update_llama_config))) llama_config.update(dict( bos_token_id=self.tokenizer.bos_token_id, eos_token_id=self.tokenizer.eos_token_id, )) llama_config.update(dict(mesh_dim=FLAGS.mesh_dim)) self.config = llama_config with jax.default_device(jax.devices("cpu")[0]): _, self.params = StreamingCheckpointer.load_trainstate_checkpoint( FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30 ) self.model = FlaxLLaMAForCausalLM( llama_config, input_shape=(512, self.block_size), seed=FLAGS.seed, _do_init=False, dtype=get_float_dtype_by_name(FLAGS.dtype), ) 
self.model_ps = match_partition_rules( LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params ) shard_fns, _ = make_shard_and_gather_fns( self.model_ps, get_float_dtype_by_name(FLAGS.dtype) ) with self.mesh: self.params = tree_apply(shard_fns, self.params) @cached_property def _forward_generate(self): def fn(params, rng, batch): batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp')) rng_generator = JaxRNG(rng) output = self.model.generate( batch['input_ids'], attention_mask=batch['attention_mask'], params=params['params'], prng_key=rng_generator(), generation_config=GenerationConfig( max_new_tokens=self.block_size, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, temperature=0., do_sample=False, num_beams=1, top_k=50, top_p=1.0, ) ).sequences[:, batch['input_ids'].shape[1]:] return output, rng_generator() return pjit( fn, in_shardings=(self.model_ps, PS(), PS()), out_shardings=(PS(), PS()) ) def __call__(self, prompts, max_input_length): inputs = self.prefix_tokenizer( prompts, padding='max_length', truncation=True, max_length=max_input_length, return_tensors='np' ) batch = dict( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask ) with self.mesh: output, self.sharded_rng = self._forward_generate( self.params, self.sharded_rng, batch ) output = jax.device_get(output) output_text = [] for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)): if self.tokenizer.eos_token in text: text = text.split(self.tokenizer.eos_token, maxsplit=1)[0] output_text.append(text) return output_text def main(argv): JaxDistributedConfig.initialize(FLAGS.jax_distributed) set_random_seed(FLAGS.seed) ht = LLMNeedleHaystackTester( haystack_file=FLAGS.haystack_file, context_lengths_min=FLAGS.context_lengths_min, context_lengths_max=FLAGS.context_lengths_max, context_lengths_num_intervals=FLAGS.n_context_length_intervals, 
document_depth_percent_intervals=FLAGS.n_document_depth_intervals, ) ht.start_test() if __name__ == "__main__": run(main) ================================================ FILE: scripts/run_eval_needle.sh ================================================ #! /bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export lwm_text_checkpoint="" # jsonl file containing text for haystack. Each line should be a json # with a single key "text" containing the text. export haystack_file="" export output_file="" python3 -u scripts/eval_needle.py \ --mesh_dim='!1,-1,4,1' \ --dtype='fp32' \ --load_llama_config='7b' \ --update_llama_config="dict(theta=10000000,max_sequence_length=131072,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \ --load_checkpoint="params::$lwm_text_checkpoint" \ --tokenizer="$llama_tokenizer_path" \ --max_tokens_per_batch=5000 \ --output_file="$output_file" \ --haystack_file="$haystack_file" \ --context_lengths_min=1000 \ --context_lengths_max=10000 \ --n_context_length_intervals=20 \ --n_document_depth_intervals=20 \ --n_rounds=3 read ================================================ FILE: scripts/run_eval_needle_multi.sh ================================================ #! /bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export lwm_text_checkpoint="" # jsonl file containing text for haystack. Each line should be a json # with a single key "text" containing the text. 
export haystack_file="" export output_file="" python3 -u scripts/eval_needle_multi.py \ --mesh_dim='!1,1,-1,1' \ --dtype='fp32' \ --load_llama_config='7b' \ --update_llama_config="dict(theta=10000000,max_sequence_length=131072,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \ --load_checkpoint="params::$lwm_text_checkpoint" \ --tokenizer="$llama_tokenizer_path" \ --max_tokens_per_batch=5000 \ --output_file="$output_file" \ --haystack_file="$haystack_file" \ --context_lengths_min=1000 \ --context_lengths_max=10000 \ --n_context_length_intervals=10 \ --n_document_depth_intervals=10 \ --n_needles_total=4 \ --n_needles_retrieve=2 \ --n_rounds=10 read ================================================ FILE: scripts/run_sample_image.sh ================================================ #! /bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export vqgan_checkpoint="" export lwm_checkpoint="" # Relevant params # --temperature_*: Temperature that is applied to each of the logits # --top_k_*: Only sample from the tokens with the top k logits # --cfg_scale_*: Classifier-free guidance scale for each modality # --n_frames: Number of frames to generate. For images specify 1. 
python3 -u -m lwm.vision_generation \ --prompt='Fireworks over the city' \ --output_file='fireworks.png' \ --temperature_image=1.0 \ --top_k_image=8192 \ --cfg_scale_image=5.0 \ --vqgan_checkpoint="$vqgan_checkpoint" \ --n_frames=1 \ --mesh_dim='!1,1,-1,1' \ --dtype='fp32' \ --load_llama_config='7b' \ --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \ --load_checkpoint="params::$lwm_checkpoint" \ --tokenizer="$llama_tokenizer_path" read ================================================ FILE: scripts/run_sample_video.sh ================================================ #! /bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export vqgan_checkpoint="" export lwm_checkpoint="" # Relevant params # --temperature_*: Temperature that is applied to each of the logits # --top_k_*: Only sample from the tokens with the top k logits # --cfg_scale_*: Classifier-free guidance scale for each modality # --n_frames: Number of frames to generate python3 -u -m lwm.vision_generation \ --prompt='Fireworks over the city' \ --output_file='fireworks.mp4' \ --temperature_image=1.0 \ --temperature_video=1.0 \ --top_k_image=8192 \ --top_k_video=1000 \ --cfg_scale_image=5.0 \ --cfg_scale_video=1.0 \ --vqgan_checkpoint="$vqgan_checkpoint" \ --n_frames=8 \ --mesh_dim='!1,1,-1,1' \ --dtype='fp32' \ --load_llama_config='7b' \ --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \ --load_checkpoint="params::$lwm_checkpoint" 
\ --tokenizer="$llama_tokenizer_path" read ================================================ FILE: scripts/run_train_text.sh ================================================ #! /bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export dataset_path="" export output_dir="" export project_id='lwm' export experiment_note='' export experiment_id='example-text-train' # mesh_dim: dp, fsdp, tp, sp python3 -u -m lwm.train \ --modality='text' \ --mesh_dim='!1,-1,2,2' \ --dtype='fp32' \ --total_steps=200 \ --log_freq=1 \ --save_model_freq=0 \ --save_milestone_freq=10 \ --load_llama_config='debug' \ --update_llama_config="dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)" \ --tokenizer="$llama_tokenizer_path" \ --optimizer.type='adamw' \ --optimizer.accumulate_gradient_steps=1 \ --optimizer.adamw_optimizer.weight_decay=0.1 \ --optimizer.adamw_optimizer.lr=8e-5 \ --optimizer.adamw_optimizer.end_lr=8e-5 \ --optimizer.adamw_optimizer.lr_warmup_steps=5 \ --optimizer.adamw_optimizer.lr_decay_steps=200 \ --use_data_sharded_loader=True \ --train_dataset.type='json' \ --train_dataset.text_processor.fields='text' \ --train_dataset.json_dataset.path="$dataset_path" \ 
--train_dataset.json_dataset.seq_length=2048 \ --train_dataset.json_dataset.batch_size=1024 \ --train_dataset.json_dataset.tokenizer_processes=16 \ --train_dataset.json_dataset.use_data_sharded_loader=True \ --checkpointer.save_optimizer_state=True \ --autoresume=False \ --logger.append_uuid=False \ --logger.online=False \ --logger.project_id="$project_id" \ --logger.experiment_id="$experiment_id" \ --logger.experiment_note="$experiment_note" \ --logger.output_dir="$output_dir" \ --logger.wandb_dir="$HOME/experiment_output/$project_id" read ================================================ FILE: scripts/run_train_vision_text.sh ================================================ #! /bin/bash export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )" cd $PROJECT_DIR export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR" export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M" export dataset_path="" export output_dir="" export project_id='lwm' export experiment_note='' export experiment_id='example-vision-text-train' # mesh_dim: dp, fsdp, tp, sp python3 -u -m lwm.train \ --modality='vision,text' \ --mesh_dim='!1,-1,2,2' \ --dtype='fp32' \ --total_steps=200 \ --log_freq=1 \ --save_model_freq=0 \ --save_milestone_freq=10 \ --load_llama_config='debug' \ 
    --update_llama_config="dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --tokenizer="$llama_tokenizer_path" \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=8e-5 \
    --optimizer.adamw_optimizer.end_lr=8e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=5 \
    --optimizer.adamw_optimizer.lr_decay_steps=200 \
    --use_data_sharded_loader=True \
    --train_dataset.type='json_vision' \
    --train_dataset.vision_text_processor.fields_from_example='fields' \
    --train_dataset.vision_text_processor.max_n_frames=4 \
    --train_dataset.json_vision_dataset.mode="no_pad" \
    --train_dataset.json_vision_dataset.path="$dataset_path" \
    --train_dataset.json_vision_dataset.seq_length=2048 \
    --train_dataset.json_vision_dataset.batch_size=8 \
    --train_dataset.json_vision_dataset.tokenizer_processes=4 \
    --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=2 \
    --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=8 \
    --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
    --checkpointer.save_optimizer_state=True \
    --autoresume=False \
    --logger.append_uuid=False \
    --logger.online=False \
    --logger.project_id="$project_id" \
    --logger.experiment_id="$experiment_id" \
    --logger.experiment_note="$experiment_note" \
    --logger.output_dir="$output_dir" \
    --logger.wandb_dir="$HOME/experiment_output/$project_id"
read

================================================
FILE: scripts/run_vision_chat.sh
================================================
#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# Quoted so that a project path containing spaces still works.
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export llama_tokenizer_path="LargeWorldModel/LWM-Text-1M"
export vqgan_checkpoint=""
export lwm_checkpoint=""
export input_file=""

# Relevant params
# --input_file: An image file (png or jpg) or a video file (any video format supported by decord, e.g. mp4)
# --max_n_frames: Maximum number of frames to process. If the video has more than max_n_frames frames, max_n_frames frames are sampled uniformly from it

python3 -u -m lwm.vision_chat \
    --prompt="What is the video about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --mesh_dim='!1,1,-1,1' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=2048,scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
read

================================================
FILE: scripts/sample_pyt.py
================================================
import argparse
from transformers import LlamaForCausalLM, LlamaTokenizer

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str, default='LargeWorldModel/LWM-Text-Chat-256K')
args = parser.parse_args()

model = LlamaForCausalLM.from_pretrained(args.model)
tokenizer = LlamaTokenizer.from_pretrained(args.model)

# The template is only relevant for chat models; non-chat models do not need it.
TEMPLATE = "You are a helpful assistant. USER: {} ASSISTANT:"
question = "What is the capital of France?"
prompt = TEMPLATE.format(question)
inputs = tokenizer(prompt, return_tensors="pt")
generate_ids = model.generate(inputs.input_ids, max_length=300)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
print(output)

================================================
FILE: tpu_requirements.sh
================================================
#! /bin/bash

sudo apt-get update && sudo apt-get install -y \
    build-essential \
    python-is-python3 \
    tmux \
    htop \
    git \
    ffmpeg

# Update pip
pip install --upgrade pip

# Python dependencies
cat > $HOME/tpu_requirements.txt <<- EndOfFile
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]==0.4.29
flax==0.8.4
optax==0.2.2
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.0
transformers==4.40.0
ringattention @ git+https://github.com/haoliuhl/ringattention.git
datasets
einops
tqdm
ml_collections
wandb
gcsfs
requests
typing-extensions
sentencepiece
tux @ git+https://github.com/haoliuhl/tux.git
Pillow
ffmpeg-python
ipdb
imageio[ffmpeg]
opencv-python
decord
h5py
psutil
EndOfFile

pip install --upgrade -r $HOME/tpu_requirements.txt

# vim configurations
cat > $HOME/.vimrc <<- EndOfFile
set tabstop=4
set shiftwidth=4
set softtabstop=4
set expandtab
set backspace=indent,eol,start
syntax on
EndOfFile

# tmux configurations
# The heredoc delimiter is quoted so that $SSH_AUTH_SOCK and $(find ...) in the
# config are written literally and evaluated by tmux at runtime, rather than
# being expanded once while this setup script runs.
cat > $HOME/.tmux.conf <<- 'EndOfFile'
bind r source-file ~/.tmux.conf \; display-message "█▓░ ~/.tmux.conf reloaded."

# Enable colors, https://github.com/tmux/tmux/wiki/FAQ
set -g default-terminal "tmux-256color"

# start with window 1 (instead of 0)
set -g base-index 1
setw -g pane-base-index 1

set -g prefix C-a

set -g set-titles on
set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)'

# Status bar customization
set -g status-interval 5
set -g status-left-length 90
set -g status-right-length 60
set -g status-justify left

# send the prefix to client inside window (ala nested sessions)
bind-key a send-prefix

bind-key x kill-pane

# auto reorder
set-option -g renumber-windows on

# default window name
set -g status-left "#[fg=green,bg=colour236] #S "

# default statusbar colors
set-option -g status-style fg=yellow,dim,bg=colour235

# default window title colors
set-window-option -g window-status-style fg=yellow,bg=colour236,dim

# active window title colors
set-window-option -g window-status-current-style fg=brightred,bg=colour236

# basename as window title https://stackoverflow.com/a/37136828
set-window-option -g window-status-current-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)'
set-window-option -g window-status-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)'

# pane border
set-option -g pane-border-style fg=white #base2
set-option -g pane-active-border-style fg=brightcyan #base1

# enable mouse click
set -g mouse on

# keep window on
set -g remain-on-exit on

# Longer scrollback history
set -g history-limit 50000

# Scroll position indicator
set -g mode-style bg=colour235,fg=colour245

# SSH agent forwarding
# set-environment -g SSH_AUTH_SOCK $SSH_AUTH_SOCK
# The variable is double-quoted so that an empty or unset value makes the test false.
if-shell '[ -n "$SSH_AUTH_SOCK" ]' " \
  set-option -sg update-environment \"DISPLAY WINDOWID XAUTHORITY\"; \
  setenv -g SSH_AUTH_SOCK /tmp/ssh_auth_sock_tmux; \
  run-shell \"ln -sf $(find /tmp/ssh-* -type s -readable | head -n 1) /tmp/ssh_auth_sock_tmux\" \
"

# Drag windows on the status bar
bind-key -n MouseDrag1Status swap-window -t=
EndOfFile

# htop Configurations
mkdir -p $HOME/.config/htop
cat > $HOME/.config/htop/htoprc <<- EndOfFile
# Beware! This file is rewritten by htop when settings are changed in the interface.
# The parser is also very primitive, and not human-friendly.
fields=0 48 17 18 38 39 40 2 46 47 49 1
sort_key=46
sort_direction=1
hide_threads=0
hide_kernel_threads=1
hide_userland_threads=1
shadow_other_users=0
show_thread_names=0
show_program_path=1
highlight_base_name=0
highlight_megabytes=1
highlight_threads=1
tree_view=0
header_margin=1
detailed_cpu_time=0
cpu_count_from_zero=0
update_process_names=0
account_guest_in_cpu_meter=0
color_scheme=0
delay=15
left_meters=CPU Memory Swap
left_meter_modes=1 1 1
right_meters=Tasks LoadAverage Uptime
right_meter_modes=2 2 2
EndOfFile
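
# Optional post-install check. This is not part of the original setup script;
# it is a minimal sketch, assuming the apt packages installed at the top of the
# script expose the commands python, tmux, htop, git and ffmpeg on PATH.
# The helper name missing_tools is hypothetical and used only for illustration.

```shell
#!/bin/bash

# Print each named command that cannot be found on PATH, one per line.
missing_tools() {
    local tool
    for tool in "$@"; do
        command -v "$tool" > /dev/null 2>&1 || printf '%s\n' "$tool"
    done
}

# Commands expected from the apt packages installed by the setup script
# (python comes from python-is-python3).
missing="$(missing_tools python tmux htop git ffmpeg)"
if [ -n "$missing" ]; then
    # Left unquoted on purpose so the names are joined on one line.
    echo "Missing tools after setup:" $missing
else
    echo "All expected tools found."
fi
```

# The check only reports what is missing and does not abort, so it can be run
# after the setup script without affecting the installed configuration.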