[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# local\njobs/\nlocal/\n.vscode/\n\ndata/\n*.model\n*.npy\n*.jsonl\n*.pkl\n*.json\n__pycache__/\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Large World Model (LWM)\n\n[[Project]](https://largeworldmodel.github.io/)\n[[Paper]](https://arxiv.org/abs/2402.08268)\n[[Models]](https://huggingface.co/LargeWorldModel)\n\n**Large World Model (LWM)** is a general-purpose large-context multimodal autoregressive model. It is trained on a large dataset of diverse long videos and books using RingAttention, and can perform language, image, and video understanding and generation.\n\n\n## Approach\n\n<div align=\"center\">\n  <img src=\"./imgs/data.png\"/>\n</div>\n\nCurrent language models fall short in understanding aspects of the world not easily described in words, and struggle with complex, long-form tasks. Video sequences offer valuable temporal information absent in language and static images, making them attractive for joint modeling with language. Such models could develop an understanding of both human textual knowledge and the physical world, enabling broader AI capabilities for assisting humans. However, learning from millions of tokens of video and language sequences poses challenges due to memory constraints, computational complexity, and limited datasets. To address these challenges, we curate a large dataset of diverse videos and books, utilize the RingAttention technique to scalably train on long sequences, and gradually increase context size from 4K to 1M tokens. This work makes the following contributions: (a) Largest context size neural network: We train one of the largest context size transformers on long video and language sequences, setting new benchmarks in difficult retrieval tasks and long video understanding. (b) Solutions for overcoming vision-language training challenges, including using masked sequence packing for mixing different sequence lengths, loss weighting to balance language and vision, and a model-generated QA dataset for long sequence chat. 
(c) A highly-optimized implementation with RingAttention, masked sequence packing, and other key features for training on million-length multimodal sequences. (d) A fully open-sourced family of 7B-parameter models capable of processing long text documents (LWM-Text, LWM-Text-Chat) and videos (LWM, LWM-Chat) of over 1M tokens.\nThis work paves the way for training on massive datasets of long video and language to develop an understanding of both human knowledge and the multimodal world, and broader capabilities.\n\n## LWM Capabilities\n\n<div align=\"center\">\n  <img src=\"./imgs/single_needle_1M.png\"/>\n  <p>\n  LWM can retrieve facts across a 1M-token context with high accuracy.\n  </p>\n</div>\n\n<br />\n\n<div align=\"center\">\n  <img src=\"./imgs/long_video_chat_main.png\"/>\n  <p>\n  LWM can answer questions about a 1-hour YouTube video.\n  </p>\n</div>\n\n<br />\n\n<div align=\"center\">\n  <img src=\"./imgs/image_chat.png\"/>\n  <p>\n  LWM can chat with images.\n  </p>\n</div>\n\n<br />\n\n<div align=\"center\">\n  <img src=\"./imgs/image_video_gen.png\"/>\n  <p>\n  LWM can generate videos and images from text.\n  </p>\n</div>\n\n\n## Setup\n\nThis codebase is supported on Ubuntu and has not been tested on Windows or macOS. We recommend using TPUs for training and inference, although it is also possible to use GPUs. On TPU, the code is highly optimized with Jax's Pallas and can achieve high MFUs with RingAttention at very large context sizes. On GPU, the code is based on XLA and is not as optimized as it is for TPU.\n\nInstall the requirements with:\n```\nconda create -n lwm python=3.10\nconda activate lwm\npip install -r gpu_requirements.txt\n```\nor set up a TPU VM with:\n```\nsh tpu_requirements.sh\n```\n\n\n## Available models\n\nThere are language-only and vision-language versions, offering context sizes of 32K, 128K, 256K, 512K, and 1M tokens. The vision-language models are available only in Jax, and the language-only models are available in both PyTorch and Jax. 
Below are the names of the available models and their corresponding context sizes and capabilities:\n\n| Model Name         | Context Size | Language or Vision-Language | Chat or Base | URL                                                                                                                                          |\n|--------------------|--------------|-----------------------------|--------------|----------------------------------------------------------------------------------------------------------------------------------------------|\n| LWM-Text-Chat-128K | 128K         | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-128K-Jax)] |\n| LWM-Text-Chat-256K | 256K         | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-256K-Jax)] |\n| LWM-Text-Chat-512K | 512K         | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-512K-Jax)] |\n| LWM-Text-Chat-1M   | 1M           | Language                    | Chat         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-Chat-1M-Jax)]     |\n| LWM-Text-128K      | 128K         | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-128K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-128K-Jax)]           |\n| LWM-Text-256K      | 256K         | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-256K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-256K-Jax)]           |\n| LWM-Text-512K      | 512K         | Language                    | Base         
| [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-512K)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-512K-Jax)]           |\n| LWM-Text-1M        | 1M           | Language                    | Base         | [[Pytorch](https://huggingface.co/LargeWorldModel/LWM-Text-1M)][[Jax](https://huggingface.co/LargeWorldModel/LWM-Text-1M-Jax)]               |\n| LWM-Chat-32K       | 32K          | Vision-Language             | Chat         | [[Jax](https://huggingface.co/LargeWorldModel/LWM-32K-Jax)]                                                                                  |\n| LWM-Chat-128K      | 128K         | Vision-Language             | Chat         | [[Jax](https://huggingface.co/LargeWorldModel/LWM-128K-Jax)]                                                                                 |\n| LWM-Chat-1M        | 1M           | Vision-Language             | Chat         | [[Jax](https://huggingface.co/LargeWorldModel/LWM-1M-Jax)]                                                                                   |\n\n\n## Code structure\nUse `scan_query_chunk_size` and `scan_key_chunk_size` to control the block size in blockwise compute of the self-attention. Use `scan_mlp_chunk_size` to control the block size in blockwise compute of the feedforward network. Use `scan_attention=True` and `scan_mlp=True` to enable/disable blockwise compute in the self-attention and feed-forward network.\n\nYou can use `mesh_dim=dp, fsdp, tp, sp` to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the number of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism.\nFor example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. 
`mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.\n\n\n## Running Jax Models\nIn this section, we provide instructions on how to run each of the provided scripts. For each script, you may need to fill in your own paths and values in the variables defined at the beginning of the script.\n\nTo run each of the following scripts, use `bash <script_name>.sh`:\n- Language model training: `bash scripts/run_train_text.sh`\n- Vision-Language model training: `bash scripts/run_train_vision_text.sh`\n- Single Needle Evals (Language Model): `bash scripts/run_eval_needle.sh`\n- Multi Needle Evals (Language Model): `bash scripts/run_eval_needle_multi.sh`\n- Sampling images (Vision-Language Model): `bash scripts/run_sample_image.sh`\n- Sampling videos (Vision-Language Model): `bash scripts/run_sample_video.sh`\n- Image / Video understanding (Vision-Language Model): `bash scripts/run_vision_chat.sh`\n\nBy default the `mesh_dim` argument puts all devices on `tp` (tensor parallelism). For longer sequences, you may want to include `sp`, which is the last dimension in the `mesh_dim`.\n\nWhen running needle evals, you may need to adjust the `theta` and `max_sequence_length` arguments in the scripts depending on the model. 
Below shows the correct values for each model.\n\n|                     | LWM-Text-128K /  LWM-Text-Chat-128K | LWM-Text-256K /  LWM-Text-Chat-256K | LWM-Text-512K / LWM-Text-Chat-512K | LWM-Text-1M / LWM-Text-Chat-1M |\n|---------------------|:-----------------------------------:|:-----------------------------------:|:----------------------------------:|:------------------------------:|\n| theta               |               10000000              |               10000000              |              25000000              |            50000000            |\n| max_sequence_length |                131072               |                262144               |               524288               |             1048576            |\n\n\nAn example of filling out a script (`run_sample_video.sh`) is as follows\n```bash\n#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport vqgan_checkpoint=\"/path/to/ckpt/folder/vqgan\"\nexport lwm_checkpoint=\"params::/path/to/ckpt/folder/params\"\n\npython3 -u -m lwm.vision_generation \\\n    --prompt='Fireworks over the city' \\\n    --output_file='fireworks.mp4' \\\n    --temperature_image=1.0 \\\n    --temperature_video=1.0 \\\n    --top_k_image=8192 \\\n    --top_k_video=1000 \\\n    --cfg_scale_image=5.0 \\\n    --cfg_scale_video=1.0 \\\n    --vqgan_checkpoint=\"$vqgan_checkpoint\" \\\n    --n_frames=8 \\\n    --mesh_dim='!1,1,-1,1' \\\n    --dtype='fp32' \\\n    --load_llama_config='7b' \\\n    --update_llama_config=\"dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)\" \\\n    
--load_checkpoint=\"$lwm_checkpoint\" \\\n    --tokenizer=\"$llama_tokenizer_path\"\nread\n```\n\n\n## Needle Haystack Data\nRun `python scripts/create_needle_data.py`\n\n\n## Running PyTorch Models\nOnly text and text chat models are currently supported for PyTorch inference. PyTorch models can be loaded as Hugging Face `LlamaForCausalLM` models. Run `python scripts/sample_pyt.py` to sample. You may need to separately install `torch`.\n\n## Documentation\n\nFor more details on the codebase, please refer to the [data.md](docs/data.md) and [sharding.md](docs/sharding.md).\nThe [data.md](docs/data.md) provides details on the data processing and the [sharding.md](docs/sharding.md) provides details on the sharding and parallelism.\n\n\n## If you have issues\n\nThis is based on the [codebase](https://github.com/haoliuhl/ringattention) of RingAttention, with the necessary features for vision-language training. The training and inference have been tested on both TPUv3 and TPUv4.\n\nIf you encounter bugs, please open a GitHub issue!\n\n\n## Citation\n\nIf you use this codebase, or otherwise found our work valuable, please cite:\n\n```\n@article{liu2023world,\n    title={World Model on Million-Length Video and Language with RingAttention},\n    author={Liu, Hao and Yan, Wilson and Zaharia, Matei and Abbeel, Pieter},\n    journal={arXiv preprint},\n    year={2024},\n}\n@article{liu2023ring,\n    title={Ring Attention with Blockwise Transformers for Near-Infinite Context},\n    author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},\n    journal={International Conference on Learning Representations},\n    year={2024}\n}\n@article{liu2023blockwise,\n    title={Blockwise Parallel Transformer for Large Context Models},\n    author={Liu, Hao and Abbeel, Pieter},\n    journal={Advances in neural information processing systems},\n    year={2023}\n}\n```\n\n## License\n\nLWM's code is released under the Apache 2.0 License. 
See [LICENSE](https://github.com/LargeWorldModel/lwm/blob/main/LICENSE) for further details. The models are released under the Llama-2 license.\n"
  },
  {
    "path": "docs/data.md",
    "content": "# Data\n\nWe support two types of datasets: Huggingface dataset and JSON dataset. The dataset modules are implemented in the [data.py](/lwm/data.py) file.\n\nConfiguration requires the dataset type, the text processor, and dataset-specific configurations.\nThe following is an example of using the Huggingface dataset to train a model:\n```bash\npython -m lwm.train \\\n    --train_dataset.text_processor.fields='text' \\\n    --train_dataset.type='huggingface' \\\n    --train_dataset.huggingface_dataset.path='openwebtext'\n```\n\nIn this example, we select the Huggingface dataset by specifying the `type` of\n`train_dataset` to be `huggingface`. We then specify the path to the dataset,\nwhich is `openwebtext` in this case. The examples loaded from the dataset will be processed\nby a TextProcessor, which is configured by the `text_processor` field.\n\nThe following options are supported for the dataset module:\n* `type`: The type of the dataset. Supported values are `huggingface` and `json`.\n* `text_processor`: The configuration of the TextProcessor used to process the\n  loaded examples.\n* `huggingface_dataset`: The configuration of the Huggingface dataset.\n* `json_dataset`: The configuration of the JSON dataset.\n\nFor the Huggingface dataset, examples are loaded from a Huggingface dataset with the following options:\n* `path`: The path to the dataset. Same as the `path` argument in\n  `datasets.load_dataset`.\n* `name`: Name of the dataset within the path. Same as the `name` argument in\n  `datasets.load_dataset`.\n* `split`: The split of the dataset. Same as the `split` argument in\n  `datasets.load_dataset`.\n* `streaming`: Whether to stream the dataset. Same as the `streaming` argument\n  in `datasets.load_dataset`.\n* `seq_length`: The length of the tokenized sequence.\n* `batch_size`: Batch size of tokenized examples.\n\nFor the JSON dataset, examples are loaded from a text file in which each line represents a\nJSON-encoded dictionary. 
Here are the configurable options for the JSON dataset:\n* `path`: Path to the text file. The file can be located on the local file system\n  or in a Google Cloud Storage bucket.\n* `seq_length`: The length of the tokenized sequence.\n* `batch_size`: Batch size of tokenized examples.\n* `start_seek_loc`: The starting seek location in the file. This is useful when\n  you want to resume training from a particular location in the file.\n* `index_at_start`: The counting index at the beginning. This is useful to\n  keep the index count when resuming from a particular location in the file.\n  Note that this is only for logging purposes and does not affect which\n  example loading starts from. To start from a different example in the dataset,\n  you should use the `start_seek_loc` option.\n* `tokenizer_processes`: The number of processes to use for tokenization.\n  Tokenization is done in parallel to speed up the loading process.\n\nA JSON dataset can be generated as follows:\n```python\nfrom datasets import load_dataset\nimport json\nfrom multiprocessing import Pool, cpu_count\n\ndataset = load_dataset(\"openwebtext\")\n\nsplit_dataset = dataset[\"train\"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)\nsplit_dataset['val'] = split_dataset.pop('test')\n\ndef save_split(split):\n    with open(f\"openwebtext_{split}.jsonl\", \"w\") as f:\n        for example in split_dataset[split]:\n            json.dump({\"text\": example[\"text\"]}, f)\n            f.write(\"\\n\")\n\nwith Pool(cpu_count()) as p:\n    p.map(save_split, [\"train\", \"val\"])\n```\n\nThis generates two files, `openwebtext_train.jsonl` and `openwebtext_val.jsonl`, which can be used as the dataset for training. 
Both files contain a single field, `text`, which is the text to be processed by the model.\nFor example, to train a model using the `openwebtext_train.jsonl` file, you can use the following command:\n```bash\npython -m lwm.train \\\n    --train_dataset.text_processor.fields='text' \\\n    --train_dataset.type='json' \\\n    --train_dataset.json_dataset.path='openwebtext_train.jsonl'\n```\n\nFor vision-language training, we recommend using the JSON dataset, as it allows you to pre-tokenize vision data (images and videos) and load the vision tokens along with the text.\n\nEach loaded example is a dictionary, which will be processed by a TextProcessor.\n\n\n## Text Processor\nWe use the `TextProcessor` class to process the loaded examples from a dataset. This allows us to flexibly process various formats.\nEach input example is a dictionary of multiple text fields. The TextProcessor will\nprocess text fields according to its configuration and return the final tokens.\n\nHere are the configurable options for TextProcessor:\n* `fields`: A comma-separated list of text fields to process.\n* `fields_from_example`: Whether to use the keys of the input example as the\n  text fields to process. If this option is set, the `fields` argument will\n  be ignored.\n* `subfield_separator`: The text separator to use when concatenating subfields\n  of a text field.\n* `add_eos_token`: Whether to add an EOS token to the end of the text.\n* `prepend_text`: The text to prepend to the beginning.\n\nThe most important configuration for TextProcessor is the `fields` argument. It\nis a comma-separated list of text fields to process. Each field consists of one\nor more subfields, which are separated by a `+`. Each subfield represents a key\nused to extract the text from the input example dictionary. The TextProcessor\njoins the extracted subfields of text with the `subfield_separator` at the text\nlevel and then tokenizes the joined text. 
Finally, the TextProcessor will concatenate\nthe tokenized text fields at the token level, and add the EOS token if specified.\n\nOther than the keys in the input example, you can also use the following special\nkeys to indicate a special token for a text field:\n* `<|bos|>`: Beginning of sentence token.\n* `<|eos|>`: End of sentence token.\n\nFor each text field, you can encapsulate the subfields with `[]` to specify that\nthe loss should not be computed for this field. Doing so sets the loss\nmasks to 0 for this field. This is useful when you want to use the text field\nas a prompt for the model.\n\n\nTo give a concrete example, if the input example looks like this:\n```python\n{\n    'vision': 'VQ tokens of a picture of a cat',\n    'question': 'what is the color of the cat',\n    'answer': 'the color of the cat is yellow',\n}\n```\nTo use `vision` and `question` as the input text, and `answer` as the target,\nwe can specify the following configuration for the `fields` argument:\n```\n[vision+question],answer\n```\n\nThe `vision+question` indicates that `vision` and `question` should be joined\ntogether with the `subfield_separator`, which is a space by default. The `[]`\nindicates that the loss should not be computed for this field. The `answer` field\nis then concatenated at the token level, where the loss will be computed.\n\n"
  },
  {
    "path": "docs/sharding.md",
    "content": "# Sharding\n\nSharding is a technique to partition the computation and the model across multiple accelerators.\nThis codebase supports flexible model and data parallelism for training and serving.\n\nThe sharding can be specified using the `mesh_dim` command line argument. The `mesh_dim` is a\ncomma separated list of integers representing the parallelism mesh axis dimensions. One of the\naxis dimensions can be `-1`, which means that the axis dimension will be inferred based on the\ntotal number of accelerators.\n\nThe first axis of the mesh is used for data parallelism (`dp`), the second axis is used for fully sharded\ndata parallelism (`fsdp`), the third axis is used for tensor parallelism (`tp`), and the fourth (last) axis is used for\nsequence parallelism (`sp`), which is required for RingAttention.\n\nFor example, `mesh_dim='1,64,4,1'` means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism, while `mesh_dim='1,1,4,64'` means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.\n\nYour total number of accelerators should be equal to the product of the mesh dimensions. For example, `mesh_dim='1,64,4,1'` and `mesh_dim='1,1,4,64'` each require 256 accelerators (1 * 64 * 4 * 1 = 1 * 1 * 4 * 64 = 256).\n\nIn general, you want to use the largest possible mesh dimension for `fsdp`. For example, `mesh_dim='1,64,1,1'` is preferred over `mesh_dim='8,8,1,1'` because the former has a larger `fsdp` dimension, which allows overlapping of computation and communication, and thus better performance.\n\nThe batch size (number of sequences per batch) should be larger than or equal to `fsdp * dp`. If you think the batch size is too large, you can allocate more accelerators to `tp` and `sp` to increase the model size and sequence length.\n\nUsing `sp` to control the sequence parallelism is required to use RingAttention. 
`sp=8` means sharding the sequence length by 8, and `sp=1` means no sequence sharding.\nFor models that use standard attention, you can set `sp=1` and use `dp`, `fsdp`, and `tp` to control the parallelism.\n"
  },
  {
    "path": "gpu_requirements.txt",
    "content": "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\njax[cuda12]==0.4.29\nflax==0.8.4\noptax==0.2.2\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch==2.0.0\ntransformers==4.40.0\nringattention @ git+https://github.com/haoliuhl/ringattention.git\ndatasets\neinops\ntqdm\nml_collections\nwandb\ngcsfs\nrequests\ntyping-extensions\nsentencepiece\ntux @ git+https://github.com/haoliuhl/tux.git\nPillow\nffmpeg-python\nipdb\nimageio[ffmpeg]\nopencv-python\ndecord\nffmpeg-python\nh5py\npsutil\n"
  },
  {
    "path": "lwm/__init__.py",
    "content": ""
  },
  {
    "path": "lwm/data.py",
    "content": "import time\nimport random\nfrom functools import partial\nimport json\nfrom multiprocessing import Pool\n\nfrom tux import open_file\nfrom ml_collections import ConfigDict\nimport numpy as np\nimport jax\nfrom jax.experimental.multihost_utils import host_local_array_to_global_array\nfrom jax.sharding import PartitionSpec as PS\nfrom datasets import load_dataset\n\n\nclass DatasetFactory(object):\n    \"\"\" Datset builder class. \"\"\"\n\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.type = 'huggingface'\n        config.text_processor = TextProcessor.get_default_config()\n        config.huggingface_dataset = HuggingfaceDataset.get_default_config()\n        config.json_dataset = JsonDataset.get_default_config()\n\n        config.vision_text_processor = VisionTextProcessor.get_default_config()\n        config.json_vision_dataset = JsonVisionDataset.get_default_config()\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    @classmethod\n    def load_dataset(cls, config, tokenizer, **kwargs):\n        config = cls.get_default_config(config)\n        if config.type == 'huggingface':\n            text_processor = TextProcessor(config.text_processor, tokenizer)\n            return HuggingfaceDataset(\n                config.huggingface_dataset, tokenizer, text_processor, **kwargs\n            )\n        elif config.type == 'json':\n            text_processor = TextProcessor(config.text_processor, tokenizer)\n            return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)\n        elif config.type == 'json_vision':\n            vision_text_processor = VisionTextProcessor(config.vision_text_processor, tokenizer)\n            return JsonVisionDataset(config.json_vision_dataset, tokenizer, vision_text_processor, **kwargs)\n        else:\n            raise ValueError(f'Unknown dataset 
type: {config.type}')\n\n    def __init__(self):\n        raise ValueError('DatasetFactory is a static class and should not be instantiated.')\n\n\nclass TextProcessor(object):\n    \"\"\" Example processor that converts a dictionary of texts into tokens. \"\"\"\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.fields_from_example = ''\n        config.fields = ''\n        config.subfield_separator = ' '\n        config.add_bos_token = True\n        config.add_eos_token = True\n        config.prepend_text = ''\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, tokenizer):\n        self.config = self.get_default_config(config)\n        assert self.config.fields != '' or self.config.fields_from_example != '', (\n            'Either fields or fields_from_example must be specified.'\n        )\n        self.tokenizer = tokenizer\n\n    def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):\n        if has_aux:\n            example, *aux = example\n        else:\n            aux = tuple()\n        token_buffer = []\n        loss_mask_buffer = []\n\n        if add_bos_token and self.config.add_bos_token:\n            token_buffer.append(self.tokenizer.bos_token_id)\n            loss_mask_buffer.append(0.0)\n\n        if self.config.fields_from_example != '':\n            fields = example[self.config.fields_from_example].split(',')\n        else:\n            fields = self.config.fields.split(',')\n\n        for i, field in enumerate(fields):\n            if field.startswith('[') and field.endswith(']'):\n                # No loss for this field.\n                field = field[1:-1]\n                mask = 0.0\n            else:\n                mask = 1.0\n\n            if field == '<|bos|>':\n                token_buffer.append(self.tokenizer.bos_token_id)\n             
   loss_mask_buffer.append(mask)\n            elif field == '<|eos|>':\n                token_buffer.append(self.tokenizer.eos_token_id)\n                loss_mask_buffer.append(mask)\n            else:\n                subfields = field.split('+')\n                text = self.config.subfield_separator.join(\n                    [example[subfield] for subfield in subfields]\n                )\n                if i == 0:\n                    text = self.config.prepend_text + text\n                tokens = self.tokenizer.encode(text, add_special_tokens=False)\n                token_buffer.extend(tokens)\n                loss_mask_buffer.extend([mask for _ in range(len(tokens))])\n\n        if add_eos_token and self.config.add_eos_token:\n            token_buffer.append(self.tokenizer.eos_token_id)\n            loss_mask_buffer.append(1.0)\n\n        return token_buffer, loss_mask_buffer, *aux\n\n\nclass VisionTextProcessor(object):\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.fields_from_example = ''\n        config.subfield_separator = ' '\n        config.add_bos_token = True\n        config.add_eos_token = True\n        config.prepend_text = ''\n        config.fields_index = -1\n        config.eof_token = 8192 # denotes end of each frame for video generation\n        config.eov_token = 8193 # denotes end of vision generation\n        config.n_tokens_per_frame = 256 # 16 x 16 VQ codes\n        config.max_n_frames = -1\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, tokenizer):\n        self.config = self.get_default_config(config)\n        assert self.config.fields_from_example != '', (\n            'fields_from_example must be specified.'\n        )\n        self.tokenizer = tokenizer\n        self.vision_start = tokenizer.encode('<vision>')\n        self.vision_end = 
tokenizer.encode('</vision>')\n\n    def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True):\n        if has_aux:\n            example, *aux = example\n        else:\n            aux = tuple()\n        rand_state = random.Random(aux[-1]) # makes augmentations deterministic by line number\n        token_buffer = []\n        loss_mask_buffer = []\n        vision_mask = []\n\n        fields = example[self.config.fields_from_example]\n        if isinstance(fields, (tuple, list)):\n            if self.config.fields_index >= 0:\n                fields = fields[self.config.fields_index]\n            else:\n                # seed based on line number\n                fields = rand_state.choice(fields)\n        fields = fields.split(',')\n\n        if add_bos_token and self.config.add_bos_token:\n            token_buffer.append(self.tokenizer.bos_token_id)\n            loss_mask_buffer.append(0.0)\n            vision_mask.append(False)\n\n        for i, field in enumerate(fields):\n            if field.startswith('[') and field.endswith(']'):\n                # No loss for this field.\n                field = field[1:-1]\n                mask = 0.0\n            else:\n                mask = 1.0\n\n            if field == '<|bos|>':\n                token_buffer.append(self.tokenizer.bos_token_id)\n                loss_mask_buffer.append(mask)\n                vision_mask.append(False)\n            elif field == '<|eos|>':\n                token_buffer.append(self.tokenizer.eos_token_id)\n                loss_mask_buffer.append(mask)\n                vision_mask.append(False)\n            elif 'vision' in field:\n                vision_tokens = example[field]\n                n_frames = int(len(vision_tokens) / self.config.n_tokens_per_frame)\n                if self.config.max_n_frames > 0 and n_frames > self.config.max_n_frames: # uniformly select\n                    idxs = np.linspace(0, n_frames - 1, self.config.max_n_frames).astype(int)\n    
                new_vision_tokens = []\n                    for idx in idxs:\n                        new_vision_tokens.extend(vision_tokens[idx * self.config.n_tokens_per_frame:(idx + 1) * self.config.n_tokens_per_frame])\n                    vision_tokens = new_vision_tokens\n                    n_frames = self.config.max_n_frames\n                assert int(len(vision_tokens) / self.config.n_tokens_per_frame) == n_frames, (int(len(vision_tokens) / self.config.n_tokens_per_frame), n_frames)\n\n                assert n_frames > 0, len(vision_tokens)\n                tokens = list(self.vision_start)\n                for j in range(n_frames):\n                    tokens.extend(vision_tokens[j*self.config.n_tokens_per_frame:(j+1)*self.config.n_tokens_per_frame])\n                    if j == n_frames - 1: # last frame\n                        tokens.append(self.config.eov_token)\n                    else:\n                        tokens.append(self.config.eof_token)\n                tokens.extend(self.vision_end)\n\n                token_buffer.extend(tokens)\n                loss_mask_buffer.extend([mask for _ in range(len(tokens))])\n                vision_mask.extend([False] * len(self.vision_start))\n                vision_mask.extend([True] * (self.config.n_tokens_per_frame * n_frames + n_frames)) # include extra eof/eov token at the end of each frame\n                vision_mask.extend([False] * len(self.vision_end))\n            else:\n                subfields = field.split('+')\n                text = self.config.subfield_separator.join(\n                    [example[subfield] for subfield in subfields]\n                )\n                if i == 0:\n                    text = self.config.prepend_text + text\n                tokens = self.tokenizer.encode(text)\n                token_buffer.extend(tokens)\n                loss_mask_buffer.extend([mask for _ in range(len(tokens))])\n                vision_mask.extend([False] * len(tokens))\n\n        if 
add_eos_token and self.config.add_eos_token:\n            token_buffer.append(self.tokenizer.eos_token_id)\n            loss_mask_buffer.append(1.0)\n            vision_mask.append(False)\n\n        assert len(token_buffer) == len(loss_mask_buffer) == len(vision_mask), (len(token_buffer), len(loss_mask_buffer), len(vision_mask))\n        keep = True\n        return token_buffer, loss_mask_buffer, vision_mask, keep, *aux\n\n\nclass HuggingfaceDataset(object):\n    \"\"\" Huggingface dataset, where the dataset is loaded using the huggingface\n        datasets.load_dataset() function.\n    \"\"\"\n\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.path = 'c4'\n        config.name = 'en'\n        config.split = 'train'\n        config.streaming = False\n        config.seq_length = 1024\n        config.batch_size = 8\n        config.always_start_with_bos = False\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, tokenizer, text_processor):\n        self.config = self.get_default_config(config)\n        name = self.config.name if self.config.name != '' else None\n        split = self.config.split if self.config.split != '' else None\n        self._tokenizer = tokenizer\n        self._text_processor = text_processor\n        self._dataset = load_dataset(\n            self.config.path, name, split=split, streaming=self.config.streaming\n        )\n\n    def __iter__(self):\n        chunk_size = self.config.batch_size * self.config.seq_length\n        total_tokens = 0\n        while True:\n            token_buffer = []\n            loss_mask_buffer = []\n            for index, example in enumerate(self._dataset):\n                tokens, loss_masks = self.text_processor(example)\n                token_buffer.extend(tokens)\n                loss_mask_buffer.extend(loss_masks)\n                while 
len(token_buffer) > chunk_size + 1:\n                    total_tokens += chunk_size\n                    metrics = {\n                        'dataset_example_index': index,\n                        'dataset_total_tokens': total_tokens,\n                    }\n                    batch = {\n                        'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(\n                            self.config.batch_size, -1\n                        ),\n                        'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(\n                            self.config.batch_size, -1\n                        ),\n                        'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(\n                            self.config.batch_size, -1\n                        ),\n                    }\n                    if self.config.always_start_with_bos:\n                        batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id\n                    yield batch, metrics\n                    token_buffer = token_buffer[chunk_size:]\n                    loss_mask_buffer = loss_mask_buffer[chunk_size:]\n\n    def get_state_dict(self):\n        return dict(config=self.config)\n\n    def load_state_dict(self, state_dict):\n        if 'config' in state_dict:\n            self.config.update(ConfigDict(state_dict['config']))\n\n    @property\n    def seq_length(self):\n        return self.config.seq_length\n\n    @property\n    def tokenizer(self):\n        return self._tokenizer\n\n    @property\n    def text_processor(self):\n        return self._text_processor\n\n    @property\n    def dataset(self):\n        return self._dataset\n\n    @property\n    def vocab_size(self):\n        return len(self._tokenizer)\n\n\nclass JsonDataset(object):\n    \"\"\" JSON dataset, where each line of the data file contains a JSON\n        dictionary with text fields.\n    \"\"\"\n\n    @staticmethod\n 
   def get_default_config(updates=None):\n        config = ConfigDict()\n        config.path = ''\n        config.seq_length = 1024\n        config.batch_size = 8\n        config.always_start_with_bos = False\n        config.start_seek_loc = 0\n        config.example_index_at_start = 0\n        config.tokens_count_at_start = 0\n        config.tokenizer_processes = 1\n        config.tokenizer_parallel_chunk_size = 32\n        config.tokenizer_parallel_batch_size = 1024\n        config.throughput_average_window_size = 200\n        config.pad = False\n        config.use_data_sharded_loader = True\n        config.return_local_batch = False\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, tokenizer, text_processor, node_info):\n        self.config = self.get_default_config(config)\n        assert self.config.path != ''\n        self._tokenizer = tokenizer\n        self._text_processor = text_processor\n        self._node_info = node_info\n        self._index = self.config.example_index_at_start\n        self._file_loc = self.config.start_seek_loc\n        self._total_tokens = self.config.tokens_count_at_start\n\n    def parse_json(self, line):\n        if not line or line == '\\n':\n            return None\n        try:\n            data = json.loads(line)\n        except json.decoder.JSONDecodeError:\n            print(f'Error parsing json line:\\n{line}')\n            return None\n        return data\n\n    def json_iterator(self):\n        index, file_loc = self._index, self._file_loc\n        with open_file(self.config.path, 'r') as fin:\n            fin.seek(file_loc)\n            while True:\n                line = fin.readline()\n                file_loc = fin.tell()\n                if not line:   # Reached EOF\n                    index = 0\n                    fin.seek(0)\n                    continue\n\n                data = 
self.parse_json(line)\n                if data is not None and (not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']):\n                    # JSON parsing succeeded\n                    yield data, file_loc, index\n                index += 1\n\n    def batched(self, iterator, batch_size):\n        batch = []\n        for example in iterator:\n            batch.append(example)\n            if len(batch) == batch_size:\n                yield batch\n                batch = []\n        if len(batch) > 0:\n            yield batch\n\n    def parallel_example_iterator(self):\n        if self.config.tokenizer_processes == 1:\n            for example, loc, index in self.json_iterator():\n                self._file_loc = loc\n                self._index = index\n                yield self.text_processor((example, loc, index), has_aux=True)\n        else:\n            process_pool = Pool(self.config.tokenizer_processes)\n            batched_iterator = self.batched(\n                self.json_iterator(), self.config.tokenizer_parallel_batch_size\n            )\n            with process_pool as pool:\n                map_fn = partial(self.text_processor, has_aux=True)\n                next_batch = pool.map_async(\n                    map_fn, next(batched_iterator),\n                    chunksize=self.config.tokenizer_parallel_chunk_size\n                )\n                while True:\n                    current_batch = next_batch\n                    next_batch = pool.map_async(\n                        map_fn, next(batched_iterator),\n                        chunksize=self.config.tokenizer_parallel_chunk_size\n                    )\n                    for example in current_batch.get():\n                        yield example\n\n    def __iter__(self):\n        global_chunk_size = self.config.batch_size * self.config.seq_length\n        if self.config.use_data_sharded_loader:\n            local_batch_size = 
self.config.batch_size // self._node_info['dp_node_size']\n        else:\n            local_batch_size = self.config.batch_size\n        chunk_size = local_batch_size * self.config.seq_length\n\n        token_buffer = []\n        loss_mask_buffer = []\n\n        last_time = 0.0\n        step_times = []\n        start_time = time.time()\n        start_tokens = self._total_tokens\n        for tokens, loss_masks, loc, index in self.parallel_example_iterator():\n            self._file_loc = loc\n            self._index = index\n            if self.config.pad:\n                tokens = tokens[:self.config.seq_length + 1]\n                tokens.extend([self._tokenizer.bos_token_id] * (self.config.seq_length + 1 - len(tokens)))\n                loss_masks = loss_masks[:self.config.seq_length + 1]\n                loss_masks.extend([0.0] * (self.config.seq_length + 1 - len(loss_masks)))\n            token_buffer.extend(tokens)\n            loss_mask_buffer.extend(loss_masks)\n            while len(token_buffer) > chunk_size + 1:\n                self._total_tokens += global_chunk_size\n                step_times.append(time.time() - last_time)\n                last_time = time.time()\n                if len(step_times) > self.config.throughput_average_window_size:\n                    step_times = step_times[-self.config.throughput_average_window_size:]\n                average_throughput = global_chunk_size / np.mean(step_times)\n                accumulated_throughput = (\n                    (self._total_tokens - start_tokens) / (time.time() - start_time)\n                )\n                metrics = {\n                    'dataset_file_loc': loc,\n                    'dataset_example_index': index,\n                    'dataset_total_tokens': self._total_tokens,\n                    'dataset_accumulated_tps': accumulated_throughput,\n                    'dataset_average_tps': average_throughput,\n                }\n                batch = {\n                    
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(\n                        local_batch_size, -1\n                    ),\n                    'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(\n                        local_batch_size, -1\n                    ),\n                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(\n                        local_batch_size, -1\n                    ),\n                }\n                batch.update({\n                    'input_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),\n                    'target_vision_masks': np.zeros(batch['input_tokens'].shape, dtype=bool),\n                })\n                if self.config.always_start_with_bos:\n                    batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id\n\n                if self.config.use_data_sharded_loader and not self.config.return_local_batch:\n                    mesh = self._node_info['mesh']\n                    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())\n                    sp_nodes_rank = jax.process_index() % sp_nodes_size\n                    assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_len, sp_nodes_size)\n                    seq_chunk_size = self.config.seq_length // sp_nodes_size\n                    batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}\n                    batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))\n\n                yield batch, metrics\n                if self.config.pad:\n                    token_buffer, loss_mask_buffer = [], []\n                else:\n                    token_buffer = token_buffer[chunk_size:]\n                    loss_mask_buffer = loss_mask_buffer[chunk_size:]\n\n    def _make_callback(self, v):\n        return lambda index: 
v[index]\n\n    def get_state_dict(self):\n        return dict(\n            config=self.config,\n            index=self._index,\n            file_loc=self._file_loc,\n            total_tokens=self._total_tokens,\n        )\n\n    def load_state_dict(self, state_dict):\n        if 'config' in state_dict:\n            self.config.update(ConfigDict(state_dict['config']))\n        self._index = state_dict.get('index', self.config.example_index_at_start)\n        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)\n        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)\n\n    @property\n    def seq_length(self):\n        return self.config.seq_length\n\n    @property\n    def tokenizer(self):\n        return self._tokenizer\n\n    @property\n    def text_processor(self):\n        return self._text_processor\n\n    @property\n    def vocab_size(self):\n        return len(self.tokenizer)\n\n\nclass JsonVisionDataset(object):\n    @staticmethod\n    def get_default_config(updates=None):\n        config = ConfigDict()\n        config.path = ''\n        config.seq_length = 384\n        config.batch_size = 4\n        config.always_start_with_bos = False\n        config.start_seek_loc = 0\n        config.example_index_at_start = 0\n        config.tokens_count_at_start = 0\n        config.tokenizer_processes = 1\n        config.tokenizer_parallel_chunk_size = 32\n        config.tokenizer_parallel_batch_size = 1024\n        config.throughput_average_window_size = 200\n        config.use_data_sharded_loader = True\n        config.return_local_batch = False\n        config.mode = 'pad'\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        return config\n\n    def __init__(self, config, tokenizer, text_processor, node_info):\n        self.config = self.get_default_config(config)\n        assert self.config.path != ''\n        self._node_info = node_info\n 
       self._tokenizer = tokenizer\n        self._text_processor = text_processor\n        self._index = self.config.example_index_at_start\n        self._file_loc = self.config.start_seek_loc\n        self._total_tokens = 0\n\n    def parse_json(self, line):\n        if not line or line == '\\n':\n            return None\n        try:\n            data = json.loads(line)\n        except json.decoder.JSONDecodeError:\n            print(f'Error parsing json line:\\n{line}')\n            return None\n        return data\n\n    def json_iterator(self):\n        index, file_loc = self._index, self._file_loc\n        with open_file(self.config.path, 'r', block_size=50 * 2 ** 20) as fin:\n            fin.seek(file_loc)\n            while True:\n                line = fin.readline()\n                file_loc = fin.tell()\n                if not line:   # Reached EOF\n                    index = 0\n                    fin.seek(0)\n                    continue\n                if not self.config.use_data_sharded_loader or index % self._node_info['dp_node_size'] == self._node_info['dp_node_rank']:\n                    data = self.parse_json(line)\n                    if data is not None:\n                        # JSON parsing succeeded\n                        yield data, file_loc, index\n                index += 1\n\n    def batched(self, iterator, batch_size):\n        batch = []\n        for example in iterator:\n            batch.append(example)\n            if len(batch) == batch_size:\n                yield batch\n                batch = []\n        if len(batch) > 0:\n            yield batch\n\n    def parallel_example_iterator(self):\n        if self.config.tokenizer_processes == 1:\n            for example, loc, index in self.json_iterator():\n                self._file_loc = loc\n                self._index = index\n                yield self.text_processor((example, loc, index), has_aux=True)\n        else:\n            process_pool = 
Pool(self.config.tokenizer_processes)\n            batched_iterator = self.batched(\n                self.json_iterator(), self.config.tokenizer_parallel_batch_size\n            )\n            with process_pool as pool:\n                map_fn = partial(self.text_processor, has_aux=True)\n                next_batch = pool.map_async(\n                    map_fn, next(batched_iterator),\n                    chunksize=self.config.tokenizer_parallel_chunk_size\n                )\n                while True:\n                    current_batch = next_batch\n                    next_batch = pool.map_async(\n                        map_fn, next(batched_iterator),\n                        chunksize=self.config.tokenizer_parallel_chunk_size\n                    )\n                    for example in current_batch.get():\n                        yield example\n\n    def __iter__(self):\n        if self.config.mode == 'pad':\n            fn = self._iter_pad\n        elif self.config.mode == 'no_pad':\n            fn = self._iter_no_pad\n        else:\n            raise ValueError(f'Unknown mode: {self.config.mode}')\n        return fn()\n\n    def _iter_pad(self):\n        chunk_size = self.config.batch_size * self.config.seq_length\n        if self.config.use_data_sharded_loader:\n            local_batch_size = self.config.batch_size // self._node_info['dp_node_size']\n        else:\n            local_batch_size = self.config.batch_size\n        last_time = 0.0\n        buffer = []\n        step_times = []\n        start_time = time.time()\n        start_tokens = self._total_tokens\n        for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():\n            if not keep:\n                continue\n            self._file_loc = loc\n            self._index = index\n            buffer.append((tokens, loss_masks, vision_masks))\n            while len(buffer) >= local_batch_size:\n                self._total_tokens += chunk_size\n                
step_times.append(time.time() - last_time)\n                last_time = time.time()\n                if len(step_times) > self.config.throughput_average_window_size:\n                    step_times = step_times[-self.config.throughput_average_window_size:]\n                average_throughput = chunk_size / np.mean(step_times)\n                accumulated_throughput = (\n                    (self._total_tokens - start_tokens) / (time.time() - start_time)\n                )\n                metrics = {\n                    'dataset_file_loc': loc,\n                    'dataset_example_index': index,\n                    'dataset_total_tokens': self._total_tokens,\n                    'dataset_accumulated_tps': accumulated_throughput,\n                    'dataset_average_tps': average_throughput,\n                }\n\n                batch = {\n                    'input_tokens': np.full(\n                        (local_batch_size, self.config.seq_length),\n                        self._tokenizer.bos_token_id,\n                        dtype=np.int32\n                    ),\n                    'target_tokens': np.full(\n                        (local_batch_size, self.config.seq_length),\n                        self._tokenizer.bos_token_id,\n                        dtype=np.int32\n                    ),\n                    'loss_masks': np.zeros(\n                        (local_batch_size, self.config.seq_length),\n                        dtype=np.float32\n                    ),\n                    'input_vision_masks': np.zeros(\n                        (local_batch_size, self.config.seq_length),\n                        dtype=bool\n                    ),\n                    'target_vision_masks': np.zeros(\n                        (local_batch_size, self.config.seq_length),\n                        dtype=bool\n                    )\n                }\n                for i in range(local_batch_size):\n                    tokens, loss_masks, vision_masks = 
buffer[i]\n                    if len(tokens) > self.config.seq_length:\n                        tokens = tokens[:self.config.seq_length + 1]\n                        loss_masks = loss_masks[1:self.config.seq_length + 1]\n                        vision_masks = vision_masks[:self.config.seq_length + 1]\n                    input_tokens, target_tokens = tokens[:-1], tokens[1:]\n                    input_vision_masks, target_vision_masks = vision_masks[:-1], vision_masks[1:]\n                    loss_masks = loss_masks[1:]\n                    batch['input_tokens'][i, :len(input_tokens)] = input_tokens\n                    batch['target_tokens'][i, :len(target_tokens)] = target_tokens\n                    batch['input_vision_masks'][i, :len(input_vision_masks)] = input_vision_masks\n                    batch['target_vision_masks'][i, :len(target_vision_masks)] = target_vision_masks\n                    batch['loss_masks'][i, :len(loss_masks)] = loss_masks\n\n                if self.config.use_data_sharded_loader and not self.config.return_local_batch:\n                    mesh = self._node_info['mesh']\n                    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())\n                    sp_nodes_rank = jax.process_index() % sp_nodes_size\n                    assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_len, sp_nodes_size)\n                    seq_chunk_size = self.config.seq_length // sp_nodes_size\n                    batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}\n                    batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))\n                yield batch, metrics\n                buffer = buffer[local_batch_size:]\n\n    def _iter_no_pad(self):\n        global_chunk_size = self.config.batch_size * self.config.seq_length\n        if self.config.use_data_sharded_loader:\n            local_batch_size = 
self.config.batch_size // self._node_info['dp_node_size']\n        else:\n            local_batch_size = self.config.batch_size\n        chunk_size = local_batch_size * self.config.seq_length\n\n        token_buffer = []\n        loss_mask_buffer = []\n        vision_mask_buffer = []\n\n        last_time = 0.0\n        step_times = []\n        start_time = time.time()\n        start_tokens = self._total_tokens\n        for tokens, loss_masks, vision_masks, keep, loc, index in self.parallel_example_iterator():\n            if not keep:\n                continue\n            self._file_loc = loc\n            self._index = index\n            token_buffer.extend(tokens)\n            loss_mask_buffer.extend(loss_masks)\n            vision_mask_buffer.extend(vision_masks)\n            while len(token_buffer) > chunk_size + 1:\n                self._total_tokens += global_chunk_size\n                step_times.append(time.time() - last_time)\n                last_time = time.time()\n                if len(step_times) > self.config.throughput_average_window_size:\n                    step_times = step_times[-self.config.throughput_average_window_size:]\n                average_throughput = global_chunk_size / np.mean(step_times)\n                accumulated_throughput = (\n                    (self._total_tokens - start_tokens) / (time.time() - start_time)\n                )\n                metrics = {\n                    'dataset_file_loc': loc,\n                    'dataset_example_index': index,\n                    'dataset_total_tokens': self._total_tokens,\n                    'dataset_accumulated_tps': accumulated_throughput,\n                    'dataset_average_tps': average_throughput,\n                }\n                batch = {\n                    'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(\n                        local_batch_size, -1\n                    ),\n                    'target_tokens': 
np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(\n                        local_batch_size, -1\n                    ),\n                    'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(\n                        local_batch_size, -1\n                    ),\n                    'input_vision_masks': np.array(vision_mask_buffer[:chunk_size], dtype=bool).reshape(\n                        local_batch_size, -1\n                    ),\n                    'target_vision_masks': np.array(vision_mask_buffer[1:chunk_size + 1], dtype=bool).reshape(\n                        local_batch_size, -1\n                    ),\n                }\n\n                if self.config.use_data_sharded_loader and not self.config.return_local_batch:\n                    mesh = self._node_info['mesh']\n                    sp_nodes_size = max(1, mesh.shape['sp'] // jax.local_device_count())\n                    sp_nodes_rank = jax.process_index() % sp_nodes_size\n                    assert self.config.seq_length % sp_nodes_size == 0, (self.config.seq_len, sp_nodes_size)\n                    seq_chunk_size = self.config.seq_length // sp_nodes_size\n                    batch = {k: v[:, sp_nodes_rank*seq_chunk_size:(sp_nodes_rank+1)*seq_chunk_size] for k, v in batch.items()}\n                    batch = host_local_array_to_global_array(batch, self._node_info['mesh'], PS(('dp', 'fsdp'), 'sp'))\n\n                yield batch, metrics\n                token_buffer = token_buffer[chunk_size:]\n                loss_mask_buffer = loss_mask_buffer[chunk_size:]\n                vision_mask_buffer = vision_mask_buffer[chunk_size:]\n\n\n    def _make_callback(self, v):\n        return lambda index: v[index]\n\n    def get_state_dict(self):\n        return dict(\n            config=self.config,\n            index=self._index,\n            file_loc=self._file_loc,\n            total_tokens=self._total_tokens,\n        )\n\n    def load_state_dict(self, 
state_dict):\n        if 'config' in state_dict:\n            self.config.update(ConfigDict(state_dict['config']))\n        self._index = state_dict.get('index', self.config.example_index_at_start)\n        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)\n        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)\n\n    @property\n    def seq_length(self):\n        return self.config.seq_length\n\n    @property\n    def tokenizer(self):\n        return self._tokenizer\n\n    @property\n    def text_processor(self):\n        return self._text_processor\n\n    @property\n    def vocab_size(self):\n        return len(self._tokenizer)\n"
  },
  {
    "path": "lwm/llama.py",
    "content": "import os\nfrom shutil import copyfile\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nimport json\nimport tempfile\nfrom functools import partial\n\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import lax\nfrom jax.sharding import PartitionSpec as PS\nfrom jax.experimental.shard_map import shard_map\nimport flax.linen as nn\nfrom flax.core.frozen_dict import FrozenDict, freeze, unfreeze\nfrom flax.linen import combine_masks, make_causal_mask\nfrom flax.traverse_util import flatten_dict, unflatten_dict\nfrom flax.linen import partitioning as nn_partitioning\n\nimport sentencepiece as spm\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\nfrom transformers.tokenization_utils import PreTrainedTokenizer\nfrom transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput\nfrom transformers.modeling_flax_utils import FlaxPreTrainedModel\nfrom transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging\n\nfrom ml_collections import ConfigDict\nfrom tux import function_args_to_config, load_pickle, open_file,  with_sharding_constraint, get_jax_mesh\nfrom ringattention import ringattention, blockwise_feedforward, ringattention_jax, ringattention_inference\n\n\nLLAMA_STANDARD_CONFIGS = {\n    '200m': {\n        'vocab_size': 32000,\n        'hidden_size': 1024,\n        'intermediate_size': 2048,\n        'num_hidden_layers': 14,\n        'num_attention_heads': 8,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    '1b': {\n        'vocab_size': 32000,\n        'hidden_size': 2048,\n        'intermediate_size': 5504,\n        'num_hidden_layers': 22,\n        'num_attention_heads': 16,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 
1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    '3b': {\n        'vocab_size': 32000,\n        'hidden_size': 3200,\n        'intermediate_size': 8640,\n        'num_hidden_layers': 26,\n        'num_attention_heads': 32,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    '7b': {\n        'vocab_size': 32000,\n        'hidden_size': 4096,\n        'intermediate_size': 11008,\n        'num_hidden_layers': 32,\n        'num_attention_heads': 32,\n        'max_sequence_length': 4096,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    '13b': {\n        'vocab_size': 32000,\n        'hidden_size': 5120,\n        'intermediate_size': 13824,\n        'num_hidden_layers': 40,\n        'num_attention_heads': 40,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    '30b': {\n        'vocab_size': 32000,\n        'hidden_size': 6656,\n        'intermediate_size': 17920,\n        'num_hidden_layers': 60,\n        'num_attention_heads': 52,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    '65b': {\n        'vocab_size': 32000,\n        'hidden_size': 8192,\n        'intermediate_size': 22016,\n        'num_hidden_layers': 80,\n        'num_attention_heads': 64,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-5,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n    'debug': { # A small model for debugging\n        'vocab_size': 32000,\n        'hidden_size': 256,\n        
'intermediate_size': 256,\n        'num_hidden_layers': 2,\n        'num_attention_heads': 2,\n        'max_sequence_length': 2048,\n        'initializer_range': 0.02,\n        'rms_norm_eps': 1e-6,\n        'use_cache': True,\n        'tie_word_embeddings': False,\n    },\n}\n\n\nclass LLaMAConfig(PretrainedConfig):\n    model_type = \"llama\"\n\n    def __init__(\n        self,\n        vocab_size=32000,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        max_sequence_length=4096,\n        rms_norm_eps=1e-6,\n        initializer_range=0.02,\n        use_cache=True,\n        bos_token_id=0,\n        eos_token_id=1,\n        resid_pdrop=0.0,\n        embd_pdrop=0.0,\n        attn_pdrop=0.0,\n        tie_word_embeddings=False,\n        scan_attention=True,\n        scan_mlp=True,\n        scan_query_chunk_size=1024,\n        scan_key_chunk_size=1024,\n        scan_mlp_chunk_size=1024,\n        scan_layers=True,\n        param_scan_axis=0,\n        mesh_dim=None,\n        theta=10000,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.initializer_range = initializer_range\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.max_sequence_length = max_sequence_length\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.resid_pdrop = resid_pdrop\n        self.embd_pdrop = embd_pdrop\n        self.attn_pdrop = attn_pdrop\n        self.scan_attention = scan_attention\n        self.scan_mlp = scan_mlp\n        self.scan_query_chunk_size = scan_query_chunk_size\n        self.scan_key_chunk_size = scan_key_chunk_size\n        self.scan_mlp_chunk_size = scan_mlp_chunk_size\n        self.scan_layers = scan_layers\n        self.param_scan_axis = param_scan_axis\n        
self.mesh_dim = mesh_dim\n        self.theta = theta\n        super().__init__(\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @classmethod\n    def get_default_config(cls, updates=None):\n        config = function_args_to_config(cls.__init__)\n\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n\n        return config\n\n    @staticmethod\n    def get_jax_mesh(axis_dims):\n        return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'tp', 'sp'))\n\n    @staticmethod\n    def get_ranks_and_size(mesh):\n        out = dict(mesh=mesh)\n        mp_size = mesh.shape['tp'] * mesh.shape['sp']\n        mp_node_size = max(1, mp_size // jax.local_device_count())\n        dp_node_size = jax.process_count() // mp_node_size\n        out.update(mp_node_size=mp_node_size,\n                   dp_node_size=dp_node_size)\n\n        dp_node_rank = jax.process_index() // mp_node_size\n        mp_node_rank = jax.process_index() % mp_node_size\n        out.update(dp_node_rank=dp_node_rank,\n                   mp_node_rank=mp_node_rank)\n        return out\n\n\n    @staticmethod\n    def get_partition_rules(scan_layers=False, scan_axis=0):\n        \"\"\"Parition rules are orderd, so that the beginning rules match first.\"\"\"\n        if scan_layers:\n            if scan_axis == 0:\n                return (\n                    # embeddings\n                    (\"transformer/wte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                    # atention\n                    (\"attention/(wq|wk|wv)/kernel\", PS(None, (\"fsdp\", \"sp\"), \"tp\")),\n                    (\"attention/wo/kernel\", PS(None, \"tp\", (\"fsdp\", \"sp\"))),\n                    # mlp\n                    (\"feed_forward/w1/kernel\", PS(None, (\"fsdp\", \"sp\"), \"tp\")),\n                    (\"feed_forward/w2/kernel\", PS(None, 
\"tp\", (\"fsdp\", \"sp\"))),\n                    (\"feed_forward/w3/kernel\", PS(None, (\"fsdp\", \"sp\"), \"tp\")),\n                    # layer norms\n                    (\"attention_norm/kernel\", PS(None, None)),\n                    (\"ffn_norm/kernel\", PS(None, None)),\n                    # output head\n                    (\"transformer/ln_f/kernel\", PS(None)),\n                    (\"lm_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                    ('.*', PS(None)),\n                )\n            elif scan_axis == 1:\n                return (\n                    # embeddings\n                    (\"transformer/wte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                    # atention\n                    (\"attention/(wq|wk|wv)/kernel\", PS((\"fsdp\", \"sp\"), None, \"tp\")),\n                    (\"attention/wo/kernel\", PS(\"tp\", None, (\"fsdp\", \"sp\"))),\n                    # mlp\n                    (\"feed_forward/w1/kernel\", PS((\"fsdp\", \"sp\"), None, \"tp\")),\n                    (\"feed_forward/w2/kernel\", PS(\"tp\", None, (\"fsdp\", \"sp\"))),\n                    (\"feed_forward/w3/kernel\", PS((\"fsdp\", \"sp\"), None, \"tp\")),\n                    # layer norms\n                    (\"attention_norm/kernel\", PS(None, None)),\n                    (\"ffn_norm/kernel\", PS(None, None)),\n                    # output head\n                    (\"transformer/ln_f/kernel\", PS(None)),\n                    (\"lm_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                    ('.*', PS(None)),\n                )\n            else:\n                raise ValueError(f\"Invalid scan_axis {scan_axis}\")\n        else:\n            return (\n                # embeddings\n                (\"transformer/wte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                # atention\n                (\"attention/(wq|wk|wv)/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                (\"attention/wo/kernel\", PS(\"tp\", 
(\"fsdp\", \"sp\"))),\n                # mlp\n                (\"feed_forward/w1/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                (\"feed_forward/w2/kernel\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                (\"feed_forward/w3/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                # layer norms\n                (\"attention_norm/kernel\", PS(None)),\n                (\"ffn_norm/kernel\", PS(None)),\n                # output head\n                (\"transformer/ln_f/kernel\", PS(None)),\n                (\"lm_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                ('.*', PS(None)),\n            )\n\n    @staticmethod\n    def get_weight_decay_exclusions():\n        return tuple()\n\n    @staticmethod\n    def get_frozen_param_exclusions(freeze_base):\n        if freeze_base:\n            return (\"vte\", \"vision_head\")\n        else:\n            return tuple()\n\n    @staticmethod\n    def rng_keys():\n        return ('params', 'dropout')\n\n    @classmethod\n    def load_config(cls, path):\n        if path in LLAMA_STANDARD_CONFIGS:\n            return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])\n        load_type, load_path = path.split('::', 1)\n        if load_type == 'pickle':\n            return cls.from_dict(load_pickle(load_path)['llama_config'])\n        elif load_type == 'json':\n            with open_file(load_path, 'r') as fin:\n                raw_config = fin.read()\n            return cls.from_dict(json.loads(raw_config))\n        else:\n            raise ValueError(f'Unsupported load config type: {load_type}')\n\n\nremat = nn_partitioning.remat\n\nlogger = logging.get_logger(__name__)\n\n\nclass RMSNorm(nn.Module):\n    dim: int\n    eps: float=1e-6\n    dtype: jnp.dtype=jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n\n    def setup(self) -> None:\n        self.weight = self.param(\n            'kernel',\n            nn.initializers.ones,\n            (self.dim,),\n            self.param_dtype,\n        )\n\n    def 
_norm(self, x: jnp.ndarray) -> jnp.ndarray:\n        return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)\n\n    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n        x = x.astype(jnp.promote_types(self.dtype, jnp.float32))\n        output = self._norm(x).astype(self.dtype)\n        weight = jnp.asarray(self.weight, self.dtype)\n        return output * weight\n\n\ndef precompute_freqs_cis(dim: int, max_position_embedding: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:\n    freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))\n    t = np.arange(max_position_embedding) # type: ignore\n    freqs = np.outer(t, freqs).astype(dtype)  # type: ignore\n    sin, cos = np.sin(freqs), np.cos(freqs)\n    freqs_cis = np.complex64(cos + 1j * sin)\n    return jnp.asarray(freqs_cis)\n\n\ndef apply_rotary_emb(\n    xq: jnp.ndarray,\n    xk: jnp.ndarray,\n    freqs_cis: jnp.ndarray,\n    dtype: jnp.dtype=jnp.float32,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n\n    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)\n    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)\n\n    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])\n    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])\n\n    # add head dim\n    freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))\n\n    xq_out = xq_ * freqs_cis\n    xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)\n\n    xk_out = xk_ * freqs_cis\n    xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)\n\n    return xq_out.astype(dtype), xk_out.astype(dtype)\n\n\nclass FlaxLLaMAAttention(nn.Module):\n    config: LLaMAConfig\n    dtype: jnp.dtype=jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self):\n        config = 
self.config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n\n        self.wq = nn.Dense(\n            config.num_attention_heads*self.head_dim,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.wk = nn.Dense(\n            config.num_attention_heads*self.head_dim,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.wv = nn.Dense(\n            config.num_attention_heads*self.head_dim,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.wo = nn.Dense(\n            config.hidden_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n\n        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)\n\n        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype=\"bool\"), dtype=\"bool\")\n\n        self.freqs_cis = precompute_freqs_cis(\n            self.head_dim,\n            config.max_sequence_length,\n            theta=config.theta,\n            dtype=self.dtype,\n        )\n\n    def _split_heads(self, hidden_states):\n        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))\n\n    def _merge_heads(self, 
hidden_states):\n        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))\n\n    @nn.compact\n    def _concatenate_to_cache(self, key, value, query, attention_mask):\n        # detect if we're initializing by absence of existing cache data.\n        is_initialized = self.has_variable(\"cache\", \"cached_key\")\n        cached_key = self.variable(\"cache\", \"cached_key\", jnp.zeros, key.shape, key.dtype)\n        cached_value = self.variable(\"cache\", \"cached_value\", jnp.zeros, value.shape, value.dtype)\n        cache_index = self.variable(\"cache\", \"cache_index\", lambda: jnp.array(0, dtype=jnp.int32))\n\n        if is_initialized:\n            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape\n            # update key, value caches with our new 1d spatial slices\n            cur_index = cache_index.value\n            if query.shape[1] == 1:\n                mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)\n                def fn(cached_key, cached_value, key, value, cur_index):\n                    assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)\n                    sp_size = max_length // mesh.shape['sp']\n                    axis_index = jax.lax.axis_index('sp')\n                    cur_index = cur_index - axis_index * sp_size\n                    key, value = jax.lax.cond(\n                        jnp.logical_and(cur_index >= 0, cur_index < sp_size),\n                        lambda: (\n                            cached_key.at[:, cur_index].set(key[:, -1]),\n                            cached_value.at[:, cur_index].set(value[:, -1]),\n                        ),\n                        lambda: (cached_key, cached_value),\n                    )\n                    return key, value\n                fn = shard_map(\n                    fn, mesh=mesh,\n                    in_specs=(\n                        PS(('dp', 'fsdp'), 'sp', 'tp', None),\n                        
PS(('dp', 'fsdp'), 'sp', 'tp', None),\n                        PS(('dp', 'fsdp'), None, 'tp', None),\n                        PS(('dp', 'fsdp'), None, 'tp', None),\n                        PS()\n                    ),\n                    out_specs=(\n                        PS(('dp', 'fsdp'), 'sp', 'tp', None),\n                        PS(('dp', 'fsdp'), 'sp', 'tp', None)\n                    ),\n                    check_rep=False\n                )\n                key, value = fn(cached_key.value, cached_value.value, key, value, cur_index)\n            else:\n                indices = (0,) * len(batch_dims) + (cur_index, 0, 0)\n                key = lax.dynamic_update_slice(cached_key.value, key, indices)\n                value = lax.dynamic_update_slice(cached_value.value, value, indices)\n            cached_key.value = key\n            cached_value.value = value\n            num_updated_cache_vectors = query.shape[1]\n            cache_index.value = cache_index.value + num_updated_cache_vectors\n        return key, value, attention_mask\n\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask,\n        segment_ids,\n        position_ids,\n        deterministic: bool = True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)\n\n        if xq.shape[1] == 1:\n            xq = with_sharding_constraint(xq, PS((\"dp\", \"fsdp\"), None, \"tp\"))\n        else:\n            xq = with_sharding_constraint(xq, PS((\"dp\", \"fsdp\"), \"sp\", \"tp\"))\n        xk = with_sharding_constraint(xk, PS((\"dp\", \"fsdp\"), \"sp\", \"tp\"))\n        xv = with_sharding_constraint(xv, PS((\"dp\", \"fsdp\"), \"sp\", \"tp\"))\n\n        xq = self._split_heads(xq)\n        xk = self._split_heads(xk)\n        xv = self._split_heads(xv)\n\n        freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)\n\n        xq, xk = 
apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)\n\n        dropout_rng = None\n        if not deterministic and self.config.attn_pdrop > 0.0:\n            dropout_rng = self.make_rng(\"dropout\")\n\n        if self.config.scan_attention and xq.shape[1] > max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size):\n            # attention mask without nxn materlization, blockwise_attn will handle the rest\n            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))\n\n            if self.has_variable(\"cache\", \"cached_key\") or init_cache:\n                xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)\n\n            # transform boolean mask into float mask\n            attention_bias = lax.select(\n                attention_mask > 0,\n                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),\n                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),\n            )\n            attn_weights = None\n            ring_attention_sharded = shard_map(\n                partial(\n                    ringattention,\n                    axis_name=\"sp\",\n                    float32_logits=True,\n                    cache_idx=None,\n                    blockwise_kwargs=dict(\n                        causal_block_size=1,\n                        deterministic=deterministic,\n                        dropout_rng=dropout_rng,\n                        attn_pdrop=self.config.attn_pdrop,\n                        query_chunk_size=self.config.scan_query_chunk_size,\n                        key_chunk_size=self.config.scan_key_chunk_size,\n                        dtype=self.dtype,\n                        policy=jax.checkpoint_policies.nothing_saveable,\n                        precision=self.precision,\n                        prevent_cse=not self.config.scan_layers,\n                    )\n                ),\n                
mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),\n                in_specs=(\n                    PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None),\n                    PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None),\n                    PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None),\n                    PS((\"dp\", \"fsdp\"), None, None, None),\n                    PS((\"dp\", \"fsdp\"), None),\n                ),\n                out_specs=PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None),\n                check_rep=False\n            )\n            attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)\n            attn_output = with_sharding_constraint(attn_output, PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None))\n        else:\n            query_length, key_length = xq.shape[1], xk.shape[1]\n\n            if self.has_variable(\"cache\", \"cached_key\"):\n                mask_shift = self.variables[\"cache\"][\"cache_index\"]\n                max_decoder_length = self.variables[\"cache\"][\"cached_key\"].shape[1]\n                causal_mask = jnp.arange(max_decoder_length)[None] <= (jnp.arange(query_length) + mask_shift)[:, None]\n                causal_mask = causal_mask[None, None]\n                segment_mask = None\n            else:\n                causal_mask = self.causal_mask[:, :, :query_length, :key_length]\n                if segment_ids is not None:\n                    segment_mask = segment_ids[:, :, None] == segment_ids[:, None, :]\n                    segment_mask = segment_mask[:, None]\n                else:\n                    segment_mask = None\n\n            batch_size = hidden_states.shape[0]\n            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])\n\n            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)\n            attention_mask = combine_masks(attention_mask, causal_mask, segment_mask)\n\n            # During fast 
autoregressive decoding, we feed one position at a time,\n            # and cache the keys and values step by step.\n            if self.has_variable(\"cache\", \"cached_key\") or init_cache:\n                xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)\n\n            q_sp_dim = None if xq.shape[1] == 1 else 'sp'\n            attn_weights = None\n            ring_attention_sharded = shard_map(\n                partial(ringattention_inference, axis_name=\"sp\"), mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),\n                in_specs=(\n                    PS((\"dp\", \"fsdp\"), q_sp_dim, \"tp\", None),\n                    PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None),\n                    PS((\"dp\", \"fsdp\"), \"sp\", \"tp\", None),\n                    PS((\"dp\", \"fsdp\"), None, q_sp_dim, None)\n                ),\n                out_specs=PS((\"dp\", \"fsdp\"), q_sp_dim, \"tp\", None),\n                check_rep=False\n            )\n            attn_output = ring_attention_sharded(\n                xq, xk, xv, attention_mask\n            )\n\n        attn_output = self._merge_heads(attn_output)\n        attn_output = self.wo(attn_output)\n        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)\n        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)\n        return outputs\n\n\nclass FlaxLLaMAMLP(nn.Module):\n    config: LLaMAConfig\n    dtype: jnp.dtype=jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self) -> None:\n        config = self.config\n\n        self.w1 = nn.Dense(\n            config.intermediate_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.w2 = nn.Dense(\n     
       config.hidden_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.w3 = nn.Dense(\n            config.intermediate_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.dropout = nn.Dropout(rate=self.config.resid_pdrop)\n\n    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:\n        x = self.w2(nn.silu(self.w1(x)) * self.w3(x))\n        x = self.dropout(x, deterministic=deterministic)\n        return x\n\n\nclass FlaxLLaMABlock(nn.Module):\n    config: LLaMAConfig\n    dtype: jnp.dtype=jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self) -> None:\n        attention_module = FlaxLLaMAAttention\n        mlp_module = FlaxLLaMAMLP\n        if self.config.scan_mlp:\n            mlp_module = remat(\n                mlp_module, static_argnums=(1,),\n                policy=jax.checkpoint_policies.nothing_saveable,\n                prevent_cse=not self.config.scan_layers,\n            )\n        self.attention = attention_module(\n            self.config,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            precision=self.precision,\n        )\n        self.feed_forward = mlp_module(\n            self.config,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            precision=self.precision,\n        )\n        self.attention_norm = RMSNorm(\n            self.config.hidden_size,\n            eps=self.config.rms_norm_eps,\n            dtype=self.dtype,\n            
param_dtype=self.param_dtype,\n        )\n        self.ffn_norm = RMSNorm(\n            self.config.hidden_size,\n            eps=self.config.rms_norm_eps,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n        )\n\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask=None,\n        segment_ids=None,\n        position_ids=None,\n        deterministic: bool = True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        attn_outputs = self.attention(\n            self.attention_norm(hidden_states),\n            attention_mask,\n            segment_ids,\n            position_ids,\n            deterministic,\n            init_cache,\n            output_attentions,\n        )\n        attn_output = attn_outputs[0]\n        hidden_states = hidden_states + attn_output\n\n        feed_forward_input = self.ffn_norm(hidden_states)\n\n        if self.config.scan_mlp and hidden_states.shape[1] >= self.config.scan_mlp_chunk_size:\n            feed_forward_hidden_states = blockwise_feedforward(\n                self.feed_forward,\n                feed_forward_input,\n                self.config.scan_mlp_chunk_size,\n                pre_remat=True,\n            )\n        else:\n            feed_forward_hidden_states = self.feed_forward(feed_forward_input, deterministic)\n        feed_forward_hidden_states = with_sharding_constraint(feed_forward_hidden_states, PS((\"dp\", \"fsdp\"), None, \"tp\"))\n\n        hidden_states = hidden_states + feed_forward_hidden_states\n\n        outputs = hidden_states\n        if self.config.scan_layers:\n            outputs = (outputs, None)\n        return outputs\n\n\nclass FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = LLaMAConfig\n    base_model_prefix = \"transformer\"\n    
module_class: nn.Module = None\n\n    def __init__(\n        self,\n        config: LLaMAConfig,\n        input_shape: Tuple = (1, 1),\n        seed: int = 0,\n        dtype: jnp.dtype = jnp.float32,\n        _do_init: bool = True,\n        **kwargs,\n    ):\n        module = self.module_class(config=config, dtype=dtype, **kwargs)\n        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)\n\n    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:\n        # init input tensors\n        input_ids = jnp.zeros(input_shape, dtype=\"i4\")\n        attention_mask = jnp.ones_like(input_ids)\n        segment_ids = None\n        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)\n        params_rng, dropout_rng = jax.random.split(rng)\n        rngs = {\"params\": params_rng, \"dropout\": dropout_rng}\n\n        if self.config.add_cross_attention:\n            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))\n            encoder_attention_mask = attention_mask\n            module_init_outputs = self.module.init(\n                rngs,\n                input_ids,\n                attention_mask,\n                segment_ids,\n                position_ids,\n                encoder_hidden_states,\n                encoder_attention_mask,\n                return_dict=False,\n            )\n        else:\n            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, segment_ids, position_ids, return_dict=False)\n\n        random_params = module_init_outputs[\"params\"]\n\n        if params is not None:\n            random_params = flatten_dict(unfreeze(random_params))\n            params = flatten_dict(unfreeze(params))\n            for missing_key in self._missing_keys:\n                params[missing_key] = random_params[missing_key]\n            self._missing_keys = set()\n         
   return freeze(unflatten_dict(params))\n        else:\n            return random_params\n\n    def init_cache(self, batch_size, max_length):\n        r\"\"\"\n        Args:\n            batch_size (`int`):\n                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.\n            max_length (`int`):\n                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized\n                cache.\n        \"\"\"\n        # init input variables to retrieve cache\n        input_ids = jnp.ones((batch_size, max_length))\n        attention_mask = jnp.ones_like(input_ids)\n        segment_ids = None\n        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)\n\n        init_variables = self.module.init(\n            jax.random.PRNGKey(0), input_ids, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True\n        )\n        return init_variables[\"cache\"]\n\n    @add_start_docstrings_to_model_forward(\"\")\n    def __call__(\n        self,\n        input_ids,\n        attention_mask=None,\n        segment_ids=None,\n        position_ids=None,\n        params: dict = None,\n        past_key_values: dict = None,\n        dropout_rng: jax.random.PRNGKey = None,\n        train: bool = False,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.return_dict\n\n        batch_size, sequence_length = input_ids.shape\n\n        if position_ids is None:\n            if 
past_key_values is not None:\n                raise ValueError(\"Make sure to provide `position_ids` when passing `past_key_values`.\")\n\n            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))\n\n        if attention_mask is None:\n            attention_mask = jnp.ones((batch_size, sequence_length))\n\n        # Handle any PRNG if needed\n        rngs = {}\n        if dropout_rng is not None:\n            rngs[\"dropout\"] = dropout_rng\n\n        inputs = {\"params\": params or self.params}\n\n        if past_key_values:\n            inputs[\"cache\"] = past_key_values\n            mutable = [\"cache\"]\n        else:\n            mutable = False\n\n        outputs = self.module.apply(\n            inputs,\n            jnp.array(input_ids, dtype=\"i4\"),\n            jnp.array(attention_mask, dtype=\"i4\"),\n            segment_ids,\n            jnp.array(position_ids, dtype=\"i4\"),\n            not train,\n            False,\n            output_attentions,\n            output_hidden_states,\n            return_dict,\n            rngs=rngs,\n            mutable=mutable,\n        )\n\n        # add updated cache to model output\n        if past_key_values is not None and return_dict:\n            outputs, past_key_values = outputs\n            outputs[\"past_key_values\"] = unfreeze(past_key_values[\"cache\"])\n            return outputs\n        elif past_key_values is not None and not return_dict:\n            outputs, past_key_values = outputs\n            outputs = outputs[:1] + (unfreeze(past_key_values[\"cache\"]),) + outputs[1:]\n\n        return outputs\n\n\nclass FlaxLLaMABlockCollection(nn.Module):\n    config: LLaMAConfig\n    dtype: jnp.dtype = jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    @nn.compact\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask=None,\n        segment_ids=None,\n       
 position_ids=None,\n        deterministic: bool = True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n\n        block = FlaxLLaMABlock\n        if self.config.scan_layers:\n            initializing = self.is_mutable_collection('params')\n            params_spec = (\n                self.config.param_scan_axis if initializing else\n                nn_partitioning.ScanIn(self.config.param_scan_axis))\n            cache_spec = 0\n            hidden_states, _ = nn.scan(\n                block,\n                variable_axes={\n                    'params': params_spec,\n                    'cache': cache_spec,\n                    'intermediates': 0\n                },\n                split_rngs={\n                    'params': True,\n                    'dropout': True\n                },\n                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),\n                length=self.config.num_hidden_layers,\n                metadata_params={nn.PARTITION_NAME: 'scan_decoder_layer'},\n                )(self.config, name='scan_decoder', dtype=self.dtype, param_dtype=self.param_dtype,)(\n                    hidden_states,\n                    attention_mask,\n                    segment_ids,\n                    position_ids,\n                    deterministic,\n                    init_cache,\n                    output_attentions,\n                )\n        else:\n            blocks = [\n                block(\n                    self.config,\n                    name=str(i),\n                    dtype=self.dtype,\n                    param_dtype=self.param_dtype,\n                ) for i in range(self.config.num_hidden_layers)\n            ]\n            for block in blocks:\n    
            if output_hidden_states:\n                    all_hidden_states += (hidden_states,)\n\n                layer_outputs = block(\n                    hidden_states,\n                    attention_mask,\n                    segment_ids,\n                    position_ids,\n                    deterministic,\n                    init_cache,\n                    output_attentions,\n                )\n                hidden_states = layer_outputs\n\n                if output_attentions:\n                    all_attentions += (layer_outputs[1],)\n\n        outputs = (hidden_states, all_hidden_states, all_attentions)\n\n        return outputs\n\n\nclass FlaxLLaMAModule(nn.Module):\n    config: LLaMAConfig\n    dtype: jnp.dtype = jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self):\n        self.embed_dim = self.config.hidden_size\n\n        self.wte = nn.Embed(\n            self.config.vocab_size,\n            self.config.hidden_size,\n            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n        )\n        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)\n        self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)\n        self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        segment_ids,\n        position_ids,\n        deterministic=True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        input_embeds = self.wte(input_ids.astype(\"i4\"))\n        assert input_embeds.shape[1] <= self.config.max_sequence_length, 
f\"Input sequence length {input_embeds.shape[1]} larger than max supported sequence length {self.config.max_sequence_length}\"\n\n        hidden_states = self.dropout(input_embeds, deterministic=deterministic)\n\n        outputs = self.h(\n            hidden_states,\n            attention_mask,\n            segment_ids=segment_ids,\n            position_ids=position_ids,\n            deterministic=deterministic,\n            init_cache=init_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        hidden_states = self.ln_f(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = outputs[1] + (hidden_states,)\n            outputs = (hidden_states, all_hidden_states) + outputs[2:]\n        else:\n            outputs = (hidden_states,) + outputs[1:]\n\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return FlaxBaseModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=outputs[1],\n            attentions=outputs[-1],\n        )\n\nclass FlaxLLaMAForCausalLMModule(nn.Module):\n    config: LLaMAConfig\n    dtype: jnp.dtype = jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self):\n        self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype)\n        self.lm_head = nn.Dense(\n            self.config.vocab_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            precision=self.precision,\n        )\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask=None,\n        segment_ids=None,\n        position_ids=None,\n        deterministic: bool = 
True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        batch_size, seq_length = input_ids.shape\n        if attention_mask is None:\n            attention_mask = jnp.ones_like(input_ids)\n        if position_ids is None:\n            position_ids = jnp.arange(seq_length, dtype=jnp.int32)[None].repeat(batch_size, axis=0)\n        outputs = self.transformer(\n            input_ids,\n            attention_mask,\n            segment_ids,\n            position_ids,\n            deterministic=deterministic,\n            init_cache=init_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n\n        if self.config.tie_word_embeddings:\n            shared_kernel = self.transformer.variables[\"params\"][\"wte\"][\"embedding\"].T\n            lm_logits = self.lm_head.apply({\"params\": {\"kernel\": shared_kernel}}, hidden_states)\n        else:\n            lm_logits = self.lm_head(hidden_states)\n\n        if not return_dict:\n            return (lm_logits,) + outputs[1:]\n\n        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)\n\n\n@add_start_docstrings(\"\", \"\")\nclass FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):\n    module_class = FlaxLLaMAForCausalLMModule\n\n    def prepare_inputs_for_generation(\n        self, input_ids, max_length,\n        attention_mask: Optional[jax.Array] = None,\n    ):\n        # initializing the cache\n        batch_size, seq_length = input_ids.shape\n\n        past_key_values = self.init_cache(batch_size, max_length)\n        extended_attention_mask = jnp.ones((batch_size, max_length), dtype=\"i4\")\n        if attention_mask is not None:\n            position_ids = attention_mask.cumsum(axis=-1) - 1\n    
        extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))\n        else:\n            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype=\"i4\")[None, :], (batch_size, seq_length))\n\n        return {\n            \"past_key_values\": past_key_values,\n            \"attention_mask\": extended_attention_mask,\n            \"position_ids\": position_ids,\n        }\n\n    def update_inputs_for_generation(self, model_outputs, model_kwargs):\n        model_kwargs[\"past_key_values\"] = model_outputs.past_key_values\n        model_kwargs[\"position_ids\"] = model_kwargs[\"position_ids\"][:, -1:] + 1\n        return model_kwargs\n"
  },
  {
    "path": "lwm/train.py",
    "content": "import pprint\nimport os\nfrom functools import partial\n\nfrom tqdm import tqdm, trange\nimport numpy as np\nfrom absl.app import run\nimport absl.logging as logging\nimport tux\n\nimport jax\nimport flax\nimport jax.numpy as jnp\nfrom jax.experimental.pjit import pjit\nfrom jax.sharding import PartitionSpec as PS\nfrom flax.training.train_state import TrainState\nfrom transformers import AutoTokenizer\n\nfrom lwm.data import DatasetFactory\nfrom tux import (\n    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,\n    cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,\n    set_random_seed, average_metrics, get_mask,\n    make_shard_and_gather_fns, with_sharding_constraint, define_flags_with_default,\n    OptimizerFactory, StreamingCheckpointer\n)\nfrom lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLMModule\nfrom lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLMModule\n\n\nFLAGS, FLAGS_DEF = define_flags_with_default(\n    modality='text',\n    use_data_sharded_loader=True,\n    seed=42,\n    mesh_dim='1,-1,1,1',\n    dtype='fp32',\n    total_steps=10000,\n    load_llama_config='',\n    update_llama_config='',\n    load_checkpoint='',\n    load_dataset_state='',\n    log_freq=50,\n    save_model_freq=0,\n    save_milestone_freq=0,\n    eval_steps=0,\n    tokenizer='LargeWorldModel/LWM-Text-1M',\n    train_dataset=DatasetFactory.get_default_config(),\n    eval_dataset=DatasetFactory.get_default_config(),\n    optimizer=OptimizerFactory.get_default_config(),\n    checkpointer=StreamingCheckpointer.get_default_config(),\n    llama=VideoLLaMAConfig.get_default_config(),\n    logger=tux.WandBLogger.get_default_config(),\n    log_all_worker=False,\n    jax_distributed=JaxDistributedConfig.get_default_config(),\n    autoresume=False,\n)\n\n\ndef main(argv):\n    JaxDistributedConfig.initialize(FLAGS.jax_distributed)\n    variant = tux.get_user_flags(FLAGS, FLAGS_DEF)\n    flags_config_dict = 
tux.user_flags_to_config_dict(FLAGS, FLAGS_DEF)\n\n    logger = tux.WandBLogger(\n        config=FLAGS.logger,\n        variant=variant,\n        enable=FLAGS.log_all_worker or (jax.process_index() == 0),\n    )\n    set_random_seed(FLAGS.seed)\n\n    if jax.process_index() == 0:\n        output_dir = logger.output_dir\n    else:\n        output_dir = os.path.join(logger.output_dir, logger.experiment_id)\n\n    if FLAGS.modality == 'text':\n        config_cls = LLaMAConfig\n        llama_cls = FlaxLLaMAForCausalLMModule\n    elif FLAGS.modality == 'vision,text':\n        config_cls = VideoLLaMAConfig\n        llama_cls = FlaxVideoLLaMAForCausalLMModule\n    else:\n        raise ValueError(f\"Unsupported modality: {FLAGS.modality}\")\n\n    mesh = config_cls.get_jax_mesh(FLAGS.mesh_dim)\n    node_info = config_cls.get_ranks_and_size(mesh)\n\n    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer, node_info=node_info)\n    if FLAGS.autoresume and tux.check_exists(output_dir):\n        logging.info('Found existing output. 
Resuming dataset from latest checkpoint...')\n        resume_path = f\"{output_dir}/dataset.pkl\"\n        dataset.load_state_dict(tux.load_pickle(resume_path))\n    elif FLAGS.load_dataset_state != '':\n        dataset.load_state_dict(tux.load_pickle(FLAGS.load_dataset_state))\n\n    if FLAGS.eval_steps > 0:\n        eval_dataset = DatasetFactory.load_dataset(\n            FLAGS.eval_dataset, dataset.tokenizer\n        )\n        eval_iterator = iter(eval_dataset)\n\n    seq_length = dataset.seq_length\n\n    if FLAGS.load_llama_config != '':\n        llama_config = config_cls.load_config(FLAGS.load_llama_config)\n        updates = config_cls(**FLAGS.llama)\n        llama_config.update(dict(\n            scan_attention=updates.scan_attention,\n            scan_mlp=updates.scan_mlp,\n            scan_query_chunk_size=updates.scan_query_chunk_size,\n            scan_key_chunk_size=updates.scan_key_chunk_size,\n            scan_mlp_chunk_size=updates.scan_mlp_chunk_size,\n            scan_layers=updates.scan_layers,\n            param_scan_axis=updates.param_scan_axis,\n        ))\n    else:\n        llama_config = config_cls(**FLAGS.llama)\n\n    if FLAGS.update_llama_config != '':\n        llama_config.update(dict(eval(FLAGS.update_llama_config)))\n\n    llama_config.update(dict(\n        bos_token_id=dataset.tokenizer.bos_token_id,\n        eos_token_id=dataset.tokenizer.eos_token_id,\n    ))\n    if llama_config.vocab_size < dataset.vocab_size:\n        llama_config.update(dict(vocab_size=dataset.vocab_size))\n    llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))\n\n    model = llama_cls(\n        llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype)\n    )\n\n    optimizer, optimizer_info = OptimizerFactory.get_optimizer(\n        FLAGS.optimizer,\n        get_mask(config_cls.get_weight_decay_exclusions()),\n        None,\n    )\n\n    def create_trainstate_from_params(params):\n        return TrainState.create(params=params, tx=optimizer, apply_fn=None)\n\n 
   def init_fn(rng):\n        rng_generator = JaxRNG(rng)\n        batch = 512\n        if FLAGS.modality == 'text':\n            params = model.init(\n                input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),\n                position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),\n                attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),\n                rngs=rng_generator(llama_config.rng_keys()),\n            )\n        elif FLAGS.modality == 'vision,text':\n            params = model.init(\n                input_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),\n                vision_masks=jnp.zeros((batch, seq_length), dtype=bool),\n                position_ids=jnp.zeros((batch, seq_length), dtype=jnp.int32),\n                attention_mask=jnp.ones((batch, seq_length), dtype=jnp.int32),\n                rngs=rng_generator(llama_config.rng_keys()),\n            )\n        else:\n            raise ValueError(f\"Unsupported modality: {FLAGS.modality}\")\n        return TrainState.create(params=params, tx=optimizer, apply_fn=None)\n\n    def train_step(train_state, rng, batch):\n        rng_generator = JaxRNG(rng)\n        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))\n        def loss_and_accuracy(params):\n            if FLAGS.modality == 'text':\n                logits = model.apply(\n                    params,\n                    batch['input_tokens'],\n                    deterministic=False,\n                    rngs=rng_generator(llama_config.rng_keys()),\n                ).logits\n                loss, acc = cross_entropy_loss_and_accuracy(\n                    logits,\n                    batch['target_tokens'],\n                    batch['loss_masks']\n                )\n                metrics = dict(acc=acc)\n                return loss, metrics\n            elif FLAGS.modality == 'vision,text':\n                vision_logits, text_logits = model.apply(\n                    
params,\n                    batch['input_tokens'],\n                    batch['input_vision_masks'],\n                    deterministic=False,\n                    rngs=rng_generator(llama_config.rng_keys()),\n                ).logits\n                vision_loss, vision_acc = cross_entropy_loss_and_accuracy(\n                    vision_logits,\n                    jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),\n                    batch['loss_masks'] * batch['target_vision_masks']\n                )\n                text_loss, text_acc = cross_entropy_loss_and_accuracy(\n                    text_logits,\n                    jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),\n                    batch['loss_masks'] * (1.0 - batch['target_vision_masks'])\n                )\n                loss = 0.5 * (vision_loss + text_loss)\n\n                metrics = dict(\n                    vision_loss=vision_loss,\n                    vision_acc=vision_acc,\n                    text_loss=text_loss,\n                    text_acc=text_acc,\n                )\n            else:\n                raise ValueError(f\"Unsupported modality: {FLAGS.modality}\")\n            return loss, metrics\n        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)\n        (loss, loss_metrics), grads = grad_fn(train_state.params)\n        train_state = train_state.apply_gradients(grads=grads)\n        metrics = dict(\n            loss=loss,\n            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),\n            param_norm=global_norm(train_state.params),\n            gradient_norm=global_norm(grads),\n            **loss_metrics\n        )\n        return train_state, rng_generator(), metrics\n\n    def eval_step(train_state, rng, batch):\n        rng_generator = JaxRNG(rng)\n        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))\n        if FLAGS.modality == 'text':\n            logits = 
model.apply(\n                train_state.params,\n                batch['input_tokens'],\n                deterministic=True,\n                rngs=rng_generator(llama_config.rng_keys()),\n            ).logits\n            loss, acc = cross_entropy_loss_and_accuracy(\n                logits,\n                batch['target_tokens'],\n                batch['loss_masks']\n            )\n            metrics = dict(\n                eval_loss=loss,\n                eval_acc=acc,\n            )\n        elif FLAGS.modality == 'vision,text':\n            vision_logits, text_logits = model.apply(\n                train_state.params,\n                batch['input_tokens'],\n                batch['input_vision_masks'],\n                deterministic=True,\n                rngs=rng_generator(llama_config.rng_keys()),\n            ).logits\n            vision_loss, vision_acc = cross_entropy_loss_and_accuracy(\n                vision_logits,\n                jnp.where(batch['target_vision_masks'], batch['target_tokens'], 0),\n                batch['loss_masks'] * batch['target_vision_masks']\n            )\n            text_loss, text_acc = cross_entropy_loss_and_accuracy(\n                text_logits,\n                jnp.where(batch['target_vision_masks'], 0, batch['target_tokens']),\n                batch['loss_masks'] * (1.0 - batch['target_vision_masks'])\n            )\n            loss = 0.5 * (vision_loss + text_loss)\n            metrics = dict(\n                eval_loss=loss,\n                eval_vision_accuracy=vision_acc,\n                eval_vision_loss=vision_loss,\n                eval_text_accuracy=text_acc,\n                eval_text_loss=text_loss,\n            )\n        return rng_generator(), metrics\n\n    train_state_shapes = jax.eval_shape(init_fn, next_rng())\n    train_state_partition = match_partition_rules(\n        config_cls.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), train_state_shapes\n    )\n\n    shard_fns, 
gather_fns = make_shard_and_gather_fns(\n        train_state_partition, train_state_shapes\n    )\n    checkpointer = StreamingCheckpointer(\n        FLAGS.checkpointer, logger.output_dir,\n        enable=jax.process_index() == 0,\n    )\n\n    sharded_init_fn = pjit(\n        init_fn,\n        in_shardings=PS(),\n        out_shardings=train_state_partition\n    )\n\n    sharded_create_trainstate_from_params = pjit(\n        create_trainstate_from_params,\n        in_shardings=(train_state_partition.params, ),\n        out_shardings=train_state_partition,\n        donate_argnums=(0, ),\n    )\n\n    if FLAGS.use_data_sharded_loader:\n        batch_spec = PS(('dp', 'fsdp'), 'sp')\n    else:\n        batch_spec = PS()\n    sharded_train_step = pjit(\n        train_step,\n        in_shardings=(train_state_partition, PS(), batch_spec),\n        out_shardings=(train_state_partition, PS(), PS()),\n        donate_argnums=(0, 1),\n    )\n\n    sharded_eval_step = pjit(\n        eval_step,\n        in_shardings=(train_state_partition, PS(), PS()),\n        out_shardings=(PS(), PS()),\n        donate_argnums=(1,),\n    )\n\n    def save_checkpoint(train_state, milestone=False):\n        step = int(jax.device_get(train_state.step))\n        metadata = dict(\n            step=step,\n            variant=variant,\n            flags=flags_config_dict,\n            llama_config=llama_config.to_dict(),\n        )\n        checkpointer.save_all(\n            train_state=train_state,\n            gather_fns=gather_fns,\n            metadata=metadata,\n            dataset=dataset.get_state_dict(),\n            milestone=milestone,\n        )\n\n    with mesh:\n        train_state, restored_params = None, None\n\n        if FLAGS.autoresume and tux.check_exists(output_dir):\n            logging.info('Found existing output. 
Resuming model from latest checkpoint...')\n            resume_path = f\"trainstate::{output_dir}/streaming_train_state\"\n            train_state, restored_params = checkpointer.load_trainstate_checkpoint(\n                resume_path, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30\n            )\n        elif FLAGS.load_checkpoint != '':\n            train_state, restored_params = checkpointer.load_trainstate_checkpoint(\n                FLAGS.load_checkpoint, train_state_shapes, shard_fns, max_buffer_size=32 * 2 ** 30\n            )\n\n        if train_state is None and restored_params is None:\n            # Initialize from scratch\n            train_state = sharded_init_fn(next_rng())\n        elif train_state is None and restored_params is not None:\n            # Restore from params but initialize train_state\n            train_state = sharded_create_trainstate_from_params(flax.core.unfreeze(restored_params))\n            del restored_params\n\n        start_step = int(jax.device_get(train_state.step))\n\n        if FLAGS.save_model_freq > 0:\n            save_checkpoint(train_state)\n\n        sharded_rng = next_rng()\n\n        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)\n        for step, (batch, dataset_metrics) in zip(step_counter, dataset):\n            train_state, sharded_rng, metrics = sharded_train_step(\n                train_state, sharded_rng, batch\n            )\n            if step % FLAGS.log_freq == 0:\n                if FLAGS.eval_steps > 0:\n                    eval_metric_list = []\n                    for _ in range(FLAGS.eval_steps):\n                        eval_batch, _ = next(eval_iterator)\n                        sharded_rng, eval_metrics = sharded_eval_step(\n                            train_state, sharded_rng, eval_batch\n                        )\n                        eval_metrics = jax.device_get(eval_metrics)\n                        eval_metric_list.append(eval_metrics)\n                   
 metrics.update(average_metrics(eval_metric_list))\n\n                log_metrics = {\"step\": step}\n                log_metrics.update(metrics)\n                log_metrics.update(dataset_metrics)\n                log_metrics = jax.device_get(log_metrics)\n                logger.log(log_metrics)\n                tqdm.write(\"\\n\" + pprint.pformat(log_metrics) + \"\\n\")\n\n            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:\n                save_checkpoint(train_state, milestone=True)\n            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:\n                save_checkpoint(train_state)\n\n        if FLAGS.save_model_freq > 0:\n            save_checkpoint(train_state)\n\n\nif __name__ == \"__main__\":\n    run(main)\n"
  },
  {
    "path": "lwm/vision_chat.py",
    "content": "from absl.app import run\nimport math\nfrom tqdm import tqdm\nfrom PIL import Image\nimport decord\nfrom functools import cached_property\nimport numpy as np\nimport jax\nfrom jax.experimental.pjit import pjit\nfrom jax.sharding import PartitionSpec as PS\nfrom transformers import GenerationConfig, AutoTokenizer\nfrom tux import (\n    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,\n    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,\n    match_partition_rules, make_shard_and_gather_fns,\n    with_sharding_constraint, tree_apply, open_file\n)\nfrom lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM\nfrom lwm.vqgan import VQGAN\n\n\nFLAGS, FLAGS_DEF = define_flags_with_default(\n    prompt=\"\",\n    input_file=\"\",\n    vqgan_checkpoint=\"\",\n    temperature=0.2,\n    max_n_frames=8,\n    seed=1234,\n    mesh_dim='1,-1,1,1',\n    dtype='fp32',\n    load_llama_config='',\n    update_llama_config='',\n    load_checkpoint='',\n    tokenizer='LargeWorldModel/LWM-Text-1M',\n    llama=VideoLLaMAConfig.get_default_config(),\n    jax_distributed=JaxDistributedConfig.get_default_config(),\n)\n\n\nclass Sampler:\n    def __init__(self):\n        self.mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)\n        self.vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)\n        self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')\n        self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n        self.n_tokens_per_frame = 257\n        self.min_buffer_size = 256\n        self.sharded_rng = next_rng()\n        self._load_model()\n\n    @property\n    def block_size(self):\n        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']\n\n    @property\n    def data_dim(self):\n        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']\n\n    def _process_frame(self, image, 
size):\n        width, height = image.size\n        if width < height:\n            new_width = size\n            new_height = int(size * height / width)\n        else:\n            new_height = size\n            new_width = int(size * width / height)\n        image = image.resize((new_width, new_height))\n\n        left = (new_width - size) / 2\n        top = (new_height - size) / 2\n        right = (new_width + size) / 2\n        bottom = (new_height + size) / 2\n        image = image.crop((left, top, right, bottom))\n        return np.array(image, dtype=np.float32) / 127.5 - 1\n\n    def _read_process_vision(self, path, max_n_frames):\n        f = open_file(path, 'rb')\n        if path.endswith('.png') or path.endswith('.jpg'):\n            image = Image.open(f).convert('RGB')\n            vision = self._process_frame(image, 256)[None]\n        else:\n            vr = decord.VideoReader(f, ctx=decord.cpu(0))\n            duration = len(vr)\n            if duration <= max_n_frames:\n                frame_id_list = list(range(duration))\n            else:\n                frame_id_list = np.linspace(0, duration - 1, max_n_frames, dtype=int).tolist()\n            video = vr.get_batch(frame_id_list).asnumpy()\n            vision = np.stack([self._process_frame(Image.fromarray(frame), 256) for frame in video])\n\n        B = 1\n        encodings = []\n        for i in range(0, len(vision), 1):\n            v = vision[i:i + B]\n            if len(v) % B == 0:\n                n_pad = 0\n            else:\n                n_pad = B - len(v) % B\n            v = np.pad(v, ((n_pad, 0), (0, 0), (0, 0), (0, 0)))\n            enc = jax.device_get(self.vqgan.encode(v))[1].astype(int)\n            enc = enc[n_pad:]\n            for t in range(len(enc)):\n                encodings.extend(enc[t].reshape(-1).tolist())\n                if t == len(enc) - 1:\n                    encodings.append(8193)\n                else:\n                    encodings.append(8192)\n        
return encodings\n\n    def construct_input(self, prompts, max_n_frames):\n        max_input_length = max_n_frames * self.n_tokens_per_frame + self.min_buffer_size\n        max_input_length = int(math.ceil(max_input_length / self.block_size) * self.block_size)\n\n        vision_start = self.tokenizer.encode('<vision>')\n        vision_end = self.tokenizer.encode('</vision>')\n\n        input_ids = np.zeros((len(prompts), max_input_length), dtype=int)\n        vision_masks = np.zeros((len(prompts), max_input_length), dtype=bool)\n        attention_mask = np.zeros((len(prompts), max_input_length), dtype=int)\n        for i, prompt in enumerate(tqdm(prompts)):\n            vision = self._read_process_vision(prompt['input_path'], max_n_frames)\n            text_1 = self.tokenizer.encode(f\"<s>You are a helpful assistant. USER: {prompt['question']}\\n\")\n            tail = self.tokenizer.encode(\" ASSISTANT:\")\n\n            tokens, vm = [], []\n            tokens.extend(text_1)\n            vm.extend([False] * len(text_1))\n            tokens.extend(vision_start)\n            vm.extend([False] * len(vision_start))\n            tokens.extend(vision)\n            vm.extend([True] * len(vision))\n            tokens.extend(vision_end)\n            vm.extend([False] * len(vision_end))\n            tokens.extend(tail)\n            vm.extend([False] * len(tail))\n            assert len(tokens) < max_input_length, (len(tokens), max_input_length)\n            assert len(tokens) == len(vm)\n            input_ids[i, -len(tokens):] = tokens\n            vision_masks[i, -len(tokens):] = vm\n            attention_mask[i, -len(tokens):] = 1\n        return {\n            'input_ids': input_ids,\n            'vision_masks': vision_masks,\n            'attention_mask': attention_mask\n        }\n\n\n    def _load_model(self):\n        if FLAGS.load_llama_config != '':\n            llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)\n            updates = 
VideoLLaMAConfig(**FLAGS.llama)\n            llama_config.update(dict(\n                scan_attention=updates.scan_attention,\n                scan_mlp=updates.scan_mlp,\n                scan_query_chunk_size=updates.scan_query_chunk_size,\n                scan_key_chunk_size=updates.scan_key_chunk_size,\n                scan_mlp_chunk_size=updates.scan_mlp_chunk_size,\n                scan_layers=updates.scan_layers,\n                param_scan_axis=updates.param_scan_axis,\n            ))\n        else:\n            llama_config = VideoLLaMAConfig(**FLAGS.llama)\n\n        if FLAGS.update_llama_config != '':\n            llama_config.update(dict(eval(FLAGS.update_llama_config)))\n\n        llama_config.update(dict(\n            bos_token_id=self.tokenizer.bos_token_id,\n            eos_token_id=self.tokenizer.eos_token_id,\n        ))\n        llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))\n        self.config = llama_config\n\n        self.model = FlaxVideoLLaMAForCausalLM(\n            llama_config,\n            input_shape=(512, self.block_size),\n            seed=FLAGS.seed,\n            _do_init=False,\n            dtype=get_float_dtype_by_name(FLAGS.dtype),\n        )\n\n        with jax.default_device(jax.devices(\"cpu\")[0]):\n            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(\n                    FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30\n            )\n        self.model_ps = match_partition_rules(\n            VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params\n        )\n        shard_fns, _ = make_shard_and_gather_fns(\n            self.model_ps, get_float_dtype_by_name(FLAGS.dtype)\n        )\n\n        with self.mesh:\n            self.params = tree_apply(shard_fns, self.params)\n\n    @cached_property\n    def _forward_generate(self):\n        def fn(params, rng, batch):\n            batch = with_sharding_constraint(batch, 
PS(('dp', 'fsdp'), 'sp'))\n            rng_generator = JaxRNG(rng)\n            output = self.model.generate(\n                batch['input_ids'],\n                vision_masks=batch['vision_masks'],\n                attention_mask=batch['attention_mask'],\n                params=params['params'],\n                prng_key=rng_generator(),\n                generation_config=GenerationConfig(\n                    max_new_tokens=self.block_size,\n                    pad_token_id=self.tokenizer.pad_token_id,\n                    eos_token_id=self.tokenizer.eos_token_id,\n                    temperature=FLAGS.temperature,\n                    do_sample=True,\n                )\n            ).sequences[:, batch['input_ids'].shape[1]:]\n            return output, rng_generator()\n        return pjit(\n            fn,\n            in_shardings=(self.model_ps, PS(), PS()),\n            out_shardings=(PS(), PS())\n        )\n\n    def __call__(self, prompts, max_n_frames):\n        batch = self.construct_input(prompts, max_n_frames)\n        with self.mesh:\n            output, self.sharded_rng = self._forward_generate(\n                self.params, self.sharded_rng, batch\n            )\n            output = jax.device_get(output)\n        output_text = []\n        for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):\n            if self.tokenizer.eos_token in text:\n                text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]\n            output_text.append(text)\n        return output_text\n\ndef main(argv):\n    assert FLAGS.prompt != ''\n    assert FLAGS.input_file != ''\n\n    JaxDistributedConfig.initialize(FLAGS.jax_distributed)\n    set_random_seed(FLAGS.seed)\n\n    prompts = [{'input_path': FLAGS.input_file, 'question': FLAGS.prompt}]\n    sampler = Sampler()\n    output = sampler(prompts, FLAGS.max_n_frames)[0]\n    print(f\"Question: {FLAGS.prompt}\\nAnswer: {output}\")\n\nif __name__ == \"__main__\":\n    run(main)\n"
  },
  {
    "path": "lwm/vision_generation.py",
    "content": "from absl.app import run\nfrom tqdm import tqdm\nimport imageio\nimport numpy as np\nfrom PIL import Image\nfrom transformers import GenerationConfig, AutoTokenizer\nimport jax\nimport jax.numpy as jnp\nfrom jax.experimental.pjit import pjit\nfrom jax.sharding import PartitionSpec as PS\nfrom tux import (\n    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,\n    set_random_seed, get_float_dtype_by_name, JaxRNG,\n    match_partition_rules, make_shard_and_gather_fns,\n    with_sharding_constraint, tree_apply, next_rng\n)\nfrom lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM\nfrom lwm.vqgan import VQGAN\n\n\nFLAGS, FLAGS_DEF = define_flags_with_default(\n    prompt='Fireworks over the city',\n    output_file='',\n    temperature_image=1.0,\n    temperature_video=1.0,\n    top_k_image=8192,\n    top_k_video=100,\n    cfg_scale_image=1.0,\n    cfg_scale_video=1.0,\n    vqgan_checkpoint='',\n    n_frames=1,\n    seed=1234,\n    mesh_dim='1,-1,1,1',\n    dtype='fp32',\n    load_llama_config='',\n    update_llama_config='',\n    load_checkpoint='',\n    tokenizer='LargeWorldModel/LWM-Text-1M',\n    llama=VideoLLaMAConfig.get_default_config(),\n    jax_distributed=JaxDistributedConfig.get_default_config(),\n)\n\n\ndef main(argv):\n    assert FLAGS.output_file != ''\n    if FLAGS.output_file.endswith('mp4'):\n        assert FLAGS.n_frames > 1\n    elif FLAGS.output_file.endswith('png') or FLAGS.output_file.endswith('jpg'):\n        assert FLAGS.n_frames == 1\n    else:\n        raise ValueError(f\"Unsupported output file extension: {FLAGS.output_file}\")\n\n    JaxDistributedConfig.initialize(FLAGS.jax_distributed)\n    set_random_seed(FLAGS.seed)\n\n    tokens_per_frame = 257\n    vqgan = VQGAN(FLAGS.vqgan_checkpoint, replicate=False)\n    mesh = VideoLLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)\n    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n    prefix_tokenizer = 
AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')\n    if FLAGS.load_llama_config != '':\n        llama_config = VideoLLaMAConfig.load_config(FLAGS.load_llama_config)\n        updates = VideoLLaMAConfig(**FLAGS.llama)\n        llama_config.update(dict(\n            scan_attention=updates.scan_attention,\n            scan_mlp=updates.scan_mlp,\n            scan_query_chunk_size=updates.scan_query_chunk_size,\n            scan_key_chunk_size=updates.scan_key_chunk_size,\n            scan_mlp_chunk_size=updates.scan_mlp_chunk_size,\n            scan_layers=updates.scan_layers,\n            param_scan_axis=updates.param_scan_axis,\n        ))\n    else:\n        llama_config = VideoLLaMAConfig(**FLAGS.llama)\n\n    if FLAGS.update_llama_config != '':\n        llama_config.update(dict(eval(FLAGS.update_llama_config)))\n\n    llama_config.update(dict(\n        bos_token_id=tokenizer.bos_token_id,\n        eos_token_id=tokenizer.eos_token_id,\n    ))\n    llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))\n\n    with jax.default_device(jax.devices(\"cpu\")[0]):\n        _, params = StreamingCheckpointer.load_trainstate_checkpoint(\n                FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30\n        )\n        model = FlaxVideoLLaMAForCausalLM(\n            llama_config,\n            input_shape=(512, 8192),\n            seed=FLAGS.seed,\n            _do_init=False,\n            dtype=get_float_dtype_by_name(FLAGS.dtype),\n        )\n        model_ps = match_partition_rules(\n            VideoLLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), params\n        )\n        shard_fns, _ = make_shard_and_gather_fns(\n            model_ps, get_float_dtype_by_name(FLAGS.dtype)\n        )\n\n        with mesh:\n            params = tree_apply(shard_fns, params)\n\n    def _forward_generate(params, rng, batch, n_tokens, cfg_scale, top_k, temperature):\n        batch = 
with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))\n        cfg_scales = jnp.ones((batch['input_ids'].shape[0] // 2,), dtype=jnp.float32) * cfg_scale\n        cfg_scales = with_sharding_constraint(cfg_scales, PS(('dp', 'fsdp')))\n        rng_generator = JaxRNG(rng)\n        output = model.generate_vision(\n            batch['input_ids'],\n            cfg_scales,\n            attention_mask=batch['attention_mask'],\n            vision_masks=batch['vision_masks'],\n            params=params['params'],\n            prng_key=rng_generator(),\n            generation_config=GenerationConfig(\n                max_new_tokens=n_tokens,\n                min_new_tokens=n_tokens,\n                pad_token_id=tokenizer.pad_token_id,\n                temperature=temperature,\n                do_sample=True,\n                top_k=top_k,\n            )\n        ).sequences[:, batch['input_ids'].shape[1]:]\n        return output, rng_generator()\n    _sharded_forward_generate = pjit(\n        _forward_generate,\n        in_shardings=(model_ps, PS(), PS()),\n        out_shardings=(PS(), PS()),\n        static_argnums=(3, 4, 5, 6)\n    )\n\n    # Generate an image or first frame (for video)\n    def generate_first_frame(prompts, max_input_length):\n        nonlocal sharded_rng\n        uncond_prompts = [\"<s><vision>\"] * len(prompts)\n        prompts = prompts + uncond_prompts\n        inputs = prefix_tokenizer(\n            prompts,\n            padding='max_length',\n            truncation=True,\n            max_length=max_input_length,\n            return_tensors='np'\n        )\n        batch = dict(\n            input_ids=inputs.input_ids,\n            attention_mask=inputs.attention_mask,\n            vision_masks=np.zeros(inputs.input_ids.shape, dtype=bool),\n        )\n        with mesh:\n            output, sharded_rng = _sharded_forward_generate(\n                params, sharded_rng, batch,\n                tokens_per_frame, FLAGS.cfg_scale_image,\n                
FLAGS.top_k_image, FLAGS.temperature_image\n            )\n            output = jax.device_get(output)\n            output = np.split(output, 2, axis=0)[0]\n        output = output.reshape(len(prompts) // 2, tokens_per_frame)\n        image = vqgan.decode(output[:, :-1].reshape(-1, 16, 16))\n        image = ((jax.device_get(image) + 1) * 127.5).astype(np.uint8)\n        return output, image\n\n    sharded_rng = next_rng()\n    prompts = [FLAGS.prompt]\n    entries = []\n    for prompt in prompts:\n        entries.append({\n            'caption': prompt,\n            'prompt': f\"<s>You are a helpful assistant. USER: Generate an image of {prompt} ASSISTANT: <vision>\",\n        })\n\n    B = 1\n    images, image_encodings = [], []\n    for i in tqdm(list(range(0, len(entries), B))):\n        entries_i = entries[i:i + B]\n        prompts = [entry['prompt'] for entry in entries_i]\n        img_enc, img = generate_first_frame(prompts, max_input_length=128)\n        image_encodings.extend(img_enc)\n        images.extend(img)\n\n    if FLAGS.n_frames == 1:\n        image = images[0]\n        Image.fromarray(image).save(FLAGS.output_file)\n        return\n\n    # Generate the rest of the video\n    def generate_video_pred(prompts, images, max_input_length):\n        nonlocal sharded_rng\n        images = np.concatenate([images, images], axis=0)\n        uncond_prompts = [\"<s><vision>\"] * len(prompts)\n        prompts = prompts + uncond_prompts\n        inputs = prefix_tokenizer(\n            prompts,\n            padding='max_length',\n            truncation=True,\n            max_length=max_input_length,\n            return_tensors='np'\n        )\n        batch = dict(\n            input_ids=np.concatenate([inputs.input_ids, images], axis=1),\n            attention_mask=np.concatenate([inputs.attention_mask, np.ones(images.shape, dtype=inputs.attention_mask.dtype)], axis=1),\n            vision_masks=np.concatenate([\n                np.zeros(inputs.input_ids.shape, 
dtype=bool),\n                np.ones(images.shape, dtype=bool)\n            ], axis=1),\n        )\n        with mesh:\n            output, sharded_rng = _sharded_forward_generate(\n                params, sharded_rng, batch,\n                (FLAGS.n_frames - 1) * tokens_per_frame, FLAGS.cfg_scale_video,\n                FLAGS.top_k_video, FLAGS.temperature_video\n            )\n            output = jax.device_get(output)\n            output = np.split(output, 2, axis=0)[0]\n        output = output.reshape(len(prompts) // 2, FLAGS.n_frames - 1, tokens_per_frame)\n        output = np.concatenate([images[:len(prompts) // 2, None], output], axis=1)\n        output = output[:, :, :-1].reshape(-1, FLAGS.n_frames, 16, 16)\n        vision = []\n        for v in output:\n            v = vqgan.decode(v)\n            v = ((jax.device_get(v) + 1) * 127.5).astype(np.uint8)\n            vision.append(v)\n        return vision\n\n    new_entries = []\n    for img_enc, entry in zip(image_encodings, entries):\n        new_entries.append({\n            'caption': entry['caption'],\n            'prompt': f\"<s>You are a helpful assistant. USER: Generate a video of {entry['caption']} ASSISTANT: <vision>\",\n            'image': np.array(img_enc, dtype=np.int32),\n        })\n    entries = new_entries\n\n    B = 1\n    videos = []\n    for i in tqdm(list(range(0, len(entries), B))):\n        entries_i = entries[i:i + B]\n        prompts = [entry['prompt'] for entry in entries_i]\n        images = np.array([entry['image'] for entry in entries_i], dtype=np.int32)\n        videos.extend(generate_video_pred(prompts, images, max_input_length=128))\n\n    video = videos[0]\n    writer = imageio.get_writer(FLAGS.output_file, fps=4)\n    for frame in video:\n        writer.append_data(frame)\n    writer.close()\n\n    print('done')\n\nif __name__ == \"__main__\":\n    run(main)\n"
  },
  {
    "path": "lwm/vision_llama.py",
    "content": "from typing import Any, Dict, List, Optional, Tuple, Union\nimport json\nimport warnings\nimport copy\n\nimport jax\nimport jax.numpy as jnp\nfrom jax import lax\nfrom jax.sharding import PartitionSpec as PS\nimport flax.linen as nn\nfrom flax.core.frozen_dict import unfreeze, freeze\nfrom flax.traverse_util import flatten_dict, unflatten_dict\n\nfrom transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput\nfrom transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel\nfrom transformers.generation.flax_utils import SampleState, FlaxLogitsProcessorList, FlaxSampleOutput, logger\nfrom transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward\nfrom transformers import GenerationConfig\n\nfrom tux import load_pickle, open_file\nfrom lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm\n\n\nVIDEO_LLAMA_STANDARD_CONFIGS = LLAMA_STANDARD_CONFIGS\n\n\nclass VideoLLaMAConfig(LLaMAConfig):\n    model_type = \"video_llama\"\n\n    def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False, sample_mode='all', **kwargs):\n        super().__init__(**kwargs)\n        self.vision_vocab_size = vision_vocab_size # 8192 + 256\n        self.tie_vision_embeddings = tie_vision_embeddings\n        self.sample_mode = sample_mode\n\n    @staticmethod\n    def get_partition_rules(scan_layers=False, scan_axis=0):\n        \"\"\"Parition rules are orderd, so that the beginning rules match first.\"\"\"\n        if scan_layers:\n            if scan_axis == 0:\n                return (\n                    # embeddings\n                    (\"transformer/wte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                    (\"transformer/vte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                    # atention\n                    (\"attention/(wq|wk|wv)/kernel\", PS(None, (\"fsdp\", \"sp\"), \"tp\")),\n                    (\"attention/wo/kernel\", PS(None, 
\"tp\", (\"fsdp\", \"sp\"))),\n                    # mlp\n                    (\"feed_forward/w1/kernel\", PS(None, (\"fsdp\", \"sp\"), \"tp\")),\n                    (\"feed_forward/w2/kernel\", PS(None, \"tp\", (\"fsdp\", \"sp\"))),\n                    (\"feed_forward/w3/kernel\", PS(None, (\"fsdp\", \"sp\"), \"tp\")),\n                    # layer norms\n                    (\"attention_norm/kernel\", PS(None, None)),\n                    (\"ffn_norm/kernel\", PS(None, None)),\n                    # output head\n                    (\"transformer/ln_f/kernel\", PS(None)),\n                    (\"lm_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                    (\"vision_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                    ('.*', PS(None)),\n                )\n            elif scan_axis == 1:\n                return (\n                    # embeddings\n                    (\"transformer/wte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                    (\"transformer/vte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                    # atention\n                    (\"attention/(wq|wk|wv)/kernel\", PS((\"fsdp\", \"sp\"), None, \"tp\")),\n                    (\"attention/wo/kernel\", PS(\"tp\", None, (\"fsdp\", \"sp\"))),\n                    # mlp\n                    (\"feed_forward/w1/kernel\", PS((\"fsdp\", \"sp\"), None, \"tp\")),\n                    (\"feed_forward/w2/kernel\", PS(\"tp\", None, (\"fsdp\", \"sp\"))),\n                    (\"feed_forward/w3/kernel\", PS((\"fsdp\", \"sp\"), None, \"tp\")),\n                    # layer norms\n                    (\"attention_norm/kernel\", PS(None, None)),\n                    (\"ffn_norm/kernel\", PS(None, None)),\n                    # output head\n                    (\"transformer/ln_f/kernel\", PS(None)),\n                    (\"lm_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                    (\"vision_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                    
('.*', PS(None)),\n                )\n            else:\n                raise ValueError(f\"Invalid scan_axis {scan_axis}\")\n        else:\n            return (\n                # embeddings\n                (\"transformer/wte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                (\"transformer/vte/embedding\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                # atention\n                (\"attention/(wq|wk|wv)/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                (\"attention/wo/kernel\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                # mlp\n                (\"feed_forward/w1/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                (\"feed_forward/w2/kernel\", PS(\"tp\", (\"fsdp\", \"sp\"))),\n                (\"feed_forward/w3/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                # layer norms\n                (\"attention_norm/kernel\", PS(None)),\n                (\"ffn_norm/kernel\", PS(None)),\n                # output head\n                (\"transformer/ln_f/kernel\", PS(None)),\n                (\"lm_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                (\"vision_head/kernel\", PS((\"fsdp\", \"sp\"), \"tp\")),\n                ('.*', PS(None)),\n            )\n\n    @classmethod\n    def load_config(cls, path):\n        if path in VIDEO_LLAMA_STANDARD_CONFIGS:\n            return cls.from_dict(VIDEO_LLAMA_STANDARD_CONFIGS[path])\n        load_type, load_path = path.split('::', 1)\n        if load_type == 'pickle':\n            return cls.from_dict(load_pickle(load_path)['llama_config'])\n        elif load_type == 'json':\n            with open_file(load_path, 'r') as fin:\n                raw_config = fin.read()\n            return cls.from_dict(json.loads(raw_config))\n        else:\n            raise ValueError(f'Unsupported load config type: {load_type}')\n\n\nclass FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for 
downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = VideoLLaMAConfig\n    base_model_prefix = \"transformer\"\n    module_class: nn.Module = None\n\n    def __init__(\n        self,\n        config: VideoLLaMAConfig,\n        input_shape: Tuple = (4, 1),\n        seed: int = 0,\n        dtype: jnp.dtype = jnp.float32,\n        _do_init: bool = True,\n        **kwargs,\n    ):\n        module = self.module_class(config=config, dtype=dtype, **kwargs)\n        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)\n\n    def init_cache(self, batch_size, max_length):\n        # init input variables to retrieve cache\n        input_ids = jnp.ones((batch_size, max_length))\n        attention_mask = jnp.ones_like(input_ids)\n        segment_ids = jnp.zeros_like(input_ids)\n        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)\n        vision_masks = jnp.ones((batch_size, max_length), dtype=bool)\n\n        init_variables = self.module.init(\n            jax.random.PRNGKey(0), input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False, init_cache=True\n        )\n        return init_variables[\"cache\"]\n\n    def init_weights(self, rng, input_shape, params=None):\n        # init input tensors\n        input_ids = jnp.zeros(input_shape, dtype=\"i4\")\n        attention_mask = jnp.ones_like(input_ids)\n        vision_masks = jnp.ones(input_ids.shape, dtype=bool)\n        segment_ids = jnp.zeros_like(input_ids)\n        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)\n        params_rng, dropout_rng = jax.random.split(rng)\n        rngs = {\"params\": params_rng, \"dropout\": dropout_rng}\n\n        random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)[\"params\"]\n\n        if params is not None:\n         
   random_params = flatten_dict(unfreeze(random_params))\n            params = flatten_dict(unfreeze(params))\n            for missing_key in self._missing_keys:\n                params[missing_key] = random_params[missing_key]\n            self._missing_keys = set()\n            return freeze(unflatten_dict(params))\n        else:\n            return random_params\n\n    @add_start_docstrings_to_model_forward(\"\")\n    def __call__(\n        self,\n        input_ids,\n        vision_masks,\n        attention_mask=None,\n        segment_ids=None,\n        position_ids=None,\n        params: dict = None,\n        past_key_values: dict = None,\n        dropout_rng: jax.random.PRNGKey = None,\n        train: bool = False,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.return_dict\n\n        batch_size, sequence_length = input_ids.shape\n\n        if position_ids is None:\n            if past_key_values is not None:\n                raise ValueError(\"Make sure to provide `position_ids` when passing `past_key_values`.\")\n\n            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))\n\n        if attention_mask is None:\n            attention_mask = jnp.ones((batch_size, sequence_length))\n\n        if segment_ids is None:\n            segment_ids = jnp.zeros((batch_size, sequence_length))\n\n        # Handle any PRNG if needed\n        rngs = {}\n        if dropout_rng is not None:\n            rngs[\"dropout\"] = dropout_rng\n\n        inputs = {\"params\": params 
or self.params}\n\n        if past_key_values:\n            inputs[\"cache\"] = past_key_values\n            mutable = [\"cache\"]\n        else:\n            mutable = False\n\n        outputs = self.module.apply(\n            inputs,\n            jnp.array(input_ids, dtype=\"i4\"),\n            jnp.array(vision_masks, dtype=\"f4\"),\n            jnp.array(attention_mask, dtype=\"i4\"),\n            jnp.array(segment_ids, dtype=\"i4\"),\n            jnp.array(position_ids, dtype=\"i4\"),\n            not train,\n            False,\n            output_attentions,\n            output_hidden_states,\n            return_dict,\n            rngs=rngs,\n            mutable=mutable,\n        )\n\n        # add updated cache to model output\n        if past_key_values is not None and return_dict:\n            outputs, past_key_values = outputs\n            outputs[\"past_key_values\"] = unfreeze(past_key_values[\"cache\"])\n            return outputs\n        elif past_key_values is not None and not return_dict:\n            outputs, past_key_values = outputs\n            outputs = outputs[:1] + (unfreeze(past_key_values[\"cache\"]),) + outputs[1:]\n\n        return outputs\n\n\nclass FlaxVideoLLaMAModule(nn.Module):\n    config: VideoLLaMAConfig\n    dtype: jnp.dtype = jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self):\n        self.embed_dim = self.config.hidden_size\n\n        self.vte = nn.Embed(\n            self.config.vision_vocab_size,\n            self.config.hidden_size,\n            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n        )\n\n        self.wte = nn.Embed(\n            self.config.vocab_size,\n            self.config.hidden_size,\n            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            dtype=self.dtype,\n    
        param_dtype=self.param_dtype,\n        )\n        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)\n        self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)\n        self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        vision_masks,\n        attention_mask,\n        segment_ids,\n        position_ids,\n        deterministic=True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        input_ids = input_ids.astype(\"i4\")\n\n        if input_ids.shape[1] == 1:\n            if self.config.sample_mode == 'text':\n                input_embeds = self.wte(input_ids)\n            elif self.config.sample_mode == 'vision':\n                input_embeds = self.vte(input_ids)\n            elif self.config.sample_mode == 'all':\n                raise NotImplementedError\n            else:\n                raise ValueError(f\"Invalid sample_mode: {self.config.sample_mode}\")\n        else:\n            input_text_embeds = self.wte(jnp.where(vision_masks, 0, input_ids))\n            input_vision_embeds = self.vte(jnp.where(vision_masks, input_ids, 0))\n            vision_masks = vision_masks[..., None].astype(\"f4\") # 1 is vision, 0 is text\n            input_embeds = input_text_embeds * (1 - vision_masks) + input_vision_embeds * vision_masks\n\n        hidden_states = self.dropout(input_embeds, deterministic=deterministic)\n\n        outputs = self.h(\n            hidden_states,\n            attention_mask,\n            segment_ids,\n            position_ids=position_ids,\n            deterministic=deterministic,\n            init_cache=init_cache,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        hidden_states = self.ln_f(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = outputs[1] + (hidden_states,)\n            outputs = (hidden_states, all_hidden_states) + outputs[2:]\n        else:\n            outputs = (hidden_states,) + outputs[1:]\n\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return FlaxBaseModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=outputs[1],\n            attentions=outputs[-1],\n        )\n\n\nclass FlaxVideoLLaMAForCausalLMModule(nn.Module):\n    config: VideoLLaMAConfig\n    dtype: jnp.dtype = jnp.float32\n    param_dtype: jnp.dtype=jnp.float32\n    precision: Optional[Union[jax.lax.Precision, str]]=None\n\n    def setup(self):\n        self.transformer = FlaxVideoLLaMAModule(self.config, dtype=self.dtype)\n        self.vision_head = nn.Dense(\n            self.config.vision_vocab_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            precision=self.precision,\n        )\n        self.lm_head = nn.Dense(\n            self.config.vocab_size,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            use_bias=False,\n            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            precision=self.precision,\n        )\n\n    def __call__(\n        self,\n        input_ids,\n        vision_masks,\n        attention_mask=None,\n        segment_ids=None,\n        position_ids=None,\n        deterministic: bool = True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: 
bool = True,\n    ):\n        batch_size, seq_length = input_ids.shape\n        if attention_mask is None:\n            attention_mask = jnp.ones_like(input_ids)\n        if segment_ids is None:\n            segment_ids = jnp.zeros_like(input_ids)\n        if position_ids is None:\n            position_ids = jnp.broadcast_to(\n                jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),\n                (batch_size, seq_length)\n            )\n\n\n        outputs = self.transformer(\n            input_ids,\n            vision_masks,\n            attention_mask,\n            segment_ids,\n            position_ids,\n            deterministic=deterministic,\n            init_cache=init_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n\n        if self.config.tie_vision_embeddings:\n            shared_kernel = self.transformer.variables[\"params\"][\"vte\"][\"embedding\"].T\n            vision_logits = self.vision_head.apply({\"params\": {\"kernel\": shared_kernel}}, hidden_states)\n        else:\n            vision_logits = self.vision_head(hidden_states)\n\n        if self.config.tie_word_embeddings:\n            shared_kernel = self.transformer.variables[\"params\"][\"wte\"][\"embedding\"].T\n            lm_logits = self.lm_head.apply({\"params\": {\"kernel\": shared_kernel}}, hidden_states)\n        else:\n            lm_logits = self.lm_head(hidden_states)\n\n        if self.config.sample_mode == 'all':\n            if not return_dict:\n                return (vision_logits, lm_logits,) + outputs[1:]\n\n            return FlaxCausalLMOutput(logits=(vision_logits, lm_logits), hidden_states=outputs.hidden_states, attentions=outputs.attentions)\n        elif self.config.sample_mode == 'vision':\n            if not return_dict:\n                return (vision_logits,) + outputs[1:]\n\n            return 
FlaxCausalLMOutput(logits=vision_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)\n        elif self.config.sample_mode == 'text':\n            if not return_dict:\n                return (lm_logits,) + outputs[1:]\n\n            return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)\n        else:\n            raise ValueError(f\"Invalid sample_mode: {self.config.sample_mode}\")\n\n\n\n@add_start_docstrings(\"\", \"\")\nclass FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel):\n    module_class = FlaxVideoLLaMAForCausalLMModule\n\n    def prepare_inputs_for_generation(\n        self, input_ids, max_length, attention_mask: Optional[jax.Array] = None, vision_masks = None\n    ):\n        # initializing the cache\n        batch_size, seq_length = input_ids.shape\n\n        past_key_values = self.init_cache(batch_size, max_length)\n        extended_attention_mask = jnp.ones((batch_size, max_length), dtype=\"i4\")\n        if attention_mask is not None:\n            position_ids = attention_mask.cumsum(axis=-1) - 1\n            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))\n        else:\n            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype=\"i4\")[None, :], (batch_size, seq_length))\n\n        return {\n            \"past_key_values\": past_key_values,\n            \"attention_mask\": extended_attention_mask,\n            \"position_ids\": position_ids,\n            \"vision_masks\": vision_masks\n        }\n\n    def update_inputs_for_generation(self, model_outputs, model_kwargs):\n        return {\n            \"past_key_values\":  model_outputs.past_key_values,\n            \"position_ids\": model_kwargs[\"position_ids\"][:, -1:] + 1,\n            \"attention_mask\": model_kwargs[\"attention_mask\"],\n            \"vision_masks\": model_kwargs[\"vision_masks\"]\n        }\n\n    def _sample_vision(\n     
   self,\n        input_ids: None,\n        max_length: Optional[int] = None,\n        pad_token_id: Optional[int] = None,\n        eos_token_id: Optional[int] = None,\n        prng_key: Optional[jnp.ndarray] = None,\n        logits_processor: Optional[FlaxLogitsProcessorList] = None,\n        logits_warper: Optional[FlaxLogitsProcessorList] = None,\n        cfg_scales: jnp.ndarray = 1.0,\n        trace: bool = True,\n        params: Optional[Dict[str, jnp.ndarray]] = None,\n        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,\n    ):\n        # init values\n        max_length = max_length if max_length is not None else self.generation_config.max_length\n        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id\n        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id\n        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)\n\n        batch_size, cur_len = input_ids.shape\n        initial_len = cur_len\n\n        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)\n        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)\n        cur_len = jnp.array(cur_len)\n\n        # per batch-item holding current token in loop.\n        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)\n        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))\n\n        # per batch-item state bit indicating if sentence has finished.\n        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)\n\n        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop\n        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.\n        model = self.decode if self.config.is_encoder_decoder else self\n\n        # initialize model specific kwargs\n        model_kwargs = 
self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)\n\n        # initialize state\n        state = SampleState(\n            cur_len=cur_len,\n            sequences=sequences,\n            running_token=input_ids,\n            is_sent_finished=is_sent_finished,\n            prng_key=prng_key,\n            model_kwargs=model_kwargs,\n        )\n\n        def sample_search_cond_fn(state):\n            \"\"\"state termination condition fn.\"\"\"\n            has_reached_max_length = state.cur_len == max_length\n            all_sequence_finished = jnp.all(state.is_sent_finished)\n            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)\n            return ~finish_generation\n\n        def sample_search_body_fn(state):\n            \"\"\"state update fn.\"\"\"\n            prng_key, prng_key_next = jax.random.split(state.prng_key)\n            model_outputs = model(state.running_token, params=params, **state.model_kwargs)\n\n            logits = model_outputs.logits[:, -1]\n            cond_logits, uncond_logits = jnp.split(logits, 2, axis=0)\n            logits = uncond_logits + cfg_scales[:, None] * (cond_logits - uncond_logits)\n\n            # apply min_length, ...\n            logits = logits_processor(state.sequences, logits, state.cur_len)\n            # apply top_p, top_k, temperature\n            logits = logits_warper(logits, logits, state.cur_len)\n\n            next_token = jax.random.categorical(prng_key, logits, axis=-1)\n            next_token = jax.lax.cond(\n                (state.cur_len - initial_len + 1) % 257 == 0,\n                lambda: jnp.full_like(next_token, 8192),\n                lambda: next_token\n            )\n            next_token = jnp.concatenate([next_token, next_token], axis=0)\n\n            #next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished\n            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)\n 
           next_token = next_token[:, None]\n\n            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))\n            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)\n\n            return SampleState(\n                cur_len=state.cur_len + 1,\n                sequences=next_sequences,\n                running_token=next_token,\n                is_sent_finished=next_is_sent_finished,\n                model_kwargs=next_model_kwargs,\n                prng_key=prng_key_next,\n            )\n\n        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU\n        if input_ids.shape[1] > 1:\n            state = sample_search_body_fn(state)\n\n        if not trace:\n            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)\n        else:\n            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)\n\n        return FlaxSampleOutput(sequences=state.sequences)\n\n    def generate_vision(\n        self,\n        input_ids: jnp.ndarray,\n        cfg_scales: jnp.ndarray,\n        generation_config: Optional[GenerationConfig] = None,\n        prng_key: Optional[jnp.ndarray] = None,\n        trace: bool = True,\n        params: Optional[Dict[str, jnp.ndarray]] = None,\n        logits_processor: Optional[FlaxLogitsProcessorList] = None,\n        **kwargs,\n    ):\n        # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call\n        self._validate_model_class()\n\n        # priority: `generation_config` argument > `model.generation_config` (the default generation config)\n        if generation_config is None:\n            # legacy: users may modify the model configuration to control generation. 
To trigger this legacy behavior,\n            # two conditions must be met\n            # 1) the generation config must have been created from the model config (`_from_model_config` field);\n            # 2) the generation config must have seen no modification since its creation (the hash is the same).\n            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(\n                self.generation_config\n            ):\n                new_generation_config = GenerationConfig.from_model_config(self.config)\n                if new_generation_config != self.generation_config:\n                    warnings.warn(\n                        \"You have modified the pretrained model configuration to control generation. This is a\"\n                        \" deprecated strategy to control generation and will be removed soon, in a future version.\"\n                        \" Please use and modify the model generation configuration (see\"\n                        \" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\"\n                    )\n                    self.generation_config = new_generation_config\n            generation_config = self.generation_config\n\n        generation_config = copy.deepcopy(generation_config)\n        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs\n        generation_config.validate()\n        self._validate_model_kwargs(model_kwargs.copy())\n\n        logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()\n\n        # set init values\n        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)\n\n        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:\n            if model_kwargs.get(\"attention_mask\") is None:\n                logger.warning(\n                    \"The attention mask and 
the pad token id were not set. As a consequence, you may observe \"\n                    \"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\"\n                )\n            eos_token_id = generation_config.eos_token_id\n            if isinstance(eos_token_id, list):\n                eos_token_id = eos_token_id[0]\n            logger.warning(f\"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.\")\n            generation_config.pad_token_id = eos_token_id\n\n        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:\n            raise ValueError(\"`decoder_start_token_id` has to be defined for encoder-decoder generation.\")\n\n        # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)\n        if not self.config.is_encoder_decoder and not trace:\n            if (\n                generation_config.pad_token_id is not None\n                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0\n            ):\n                logger.warning(\n                    \"A decoder-only architecture is being used, but right-padding was detected! 
For correct \"\n                    \"generation results, please set `padding_side='left'` when initializing the tokenizer.\"\n                )\n\n        batch_size = input_ids.shape[0]\n\n        if self.config.is_encoder_decoder:\n            # add encoder_outputs to model_kwargs\n            if model_kwargs.get(\"encoder_outputs\") is None:\n                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)\n            # prepare decoder_input_ids for generation\n            input_ids = self._prepare_decoder_input_ids_for_generation(\n                batch_size,\n                decoder_start_token_id=generation_config.decoder_start_token_id,\n                bos_token_id=generation_config.bos_token_id,\n                model_kwargs=model_kwargs,\n            )\n\n        # Prepare `max_length` depending on other stopping criteria.\n        input_ids_seq_length = input_ids.shape[-1]\n        has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:\n            # 20 is the default max_length of the generation config\n            warnings.warn(\n                f\"Using the model-agnostic default `max_length` (={generation_config.max_length}) \"\n                \"to control the generation length.  recommend setting `max_new_tokens` to control the maximum length of the generation.\",\n                UserWarning,\n            )\n        elif generation_config.max_new_tokens is not None:\n            if not has_default_max_length and generation_config.max_length is not None:\n                logger.warning(\n                    f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n                    f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. 
\"\n                    \"Please refer to the documentation for more information. \"\n                    \"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\"\n                )\n            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n\n        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:\n            raise ValueError(\n                f\"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than\"\n                f\" the maximum length ({generation_config.max_length})\"\n            )\n        if input_ids_seq_length >= generation_config.max_length:\n            input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n            logger.warning(\n                f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n                f\" {generation_config.max_length}. This can lead to unexpected behavior. 
You should consider\"\n                \" increasing`max_new_tokens`.\"\n            )\n\n        logits_processor = self._get_logits_processor(\n            generation_config=generation_config,\n            input_ids_seq_length=input_ids_seq_length,\n            logits_processor=logits_processor,\n        )\n\n        if not generation_config.do_sample and generation_config.num_beams == 1:\n            raise NotImplementedError\n        elif generation_config.do_sample and generation_config.num_beams == 1:\n            logits_warper = self._get_logits_warper(generation_config=generation_config)\n            return self._sample_vision(\n                input_ids,\n                generation_config.max_length,\n                generation_config.pad_token_id,\n                generation_config.eos_token_id,\n                prng_key,\n                logits_warper=logits_warper,\n                logits_processor=logits_processor,\n                cfg_scales=cfg_scales,\n                trace=trace,\n                params=params,\n                model_kwargs=model_kwargs,\n            )\n        elif not generation_config.do_sample and generation_config.num_beams > 1:\n            raise NotImplementedError\n        else:\n            raise NotImplementedError(\"`Beam sampling is currently not implemented.\")\n"
  },
  {
    "path": "lwm/vqgan.py",
    "content": "from typing import Optional\nfrom functools import cached_property, partial\nimport pickle\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nimport flax.linen as nn\nfrom flax import jax_utils\nfrom transformers.configuration_utils import PretrainedConfig\nfrom ml_collections import ConfigDict\nfrom tux import function_args_to_config, open_file\n\n\nclass VQGAN:\n    \"\"\"Inference wrapper that loads pickled VQGAN params and exposes compiled encode/decode.\n\n    Args:\n        vqgan_checkpoint: path handed to `tux.open_file`; must be non-empty.\n        replicate: if True, params are replicated over local devices and the\n            wrapped functions run under `jax.pmap`; otherwise under `jax.jit`.\n    \"\"\"\n\n    def __init__(self, vqgan_checkpoint, replicate=False):\n        assert vqgan_checkpoint != ''\n        self.replicate = replicate\n        self.config = VQGANConfig.get_default_config()\n        # SECURITY: pickle.load can execute arbitrary code; only load trusted checkpoints.\n        # The `with` block closes the file handle, which was previously leaked.\n        with open_file(vqgan_checkpoint, 'rb') as fin:\n            self.params = pickle.load(fin)\n        if replicate:\n            self.params = jax_utils.replicate(self.params)\n        else:\n            # Identity jit materializes the params as device arrays up front.\n            self.params = jax.jit(lambda x: x)(self.params)\n        self.model = VQGANModel(self.config)\n\n    def _wrap_fn(self, fn):\n        # pmap over local devices when params are replicated, plain jit otherwise.\n        if self.replicate:\n            return jax.pmap(fn, devices=jax.local_devices())\n        else:\n            return jax.jit(fn)\n\n    @cached_property\n    def _encode(self):\n        # Built on first access and cached on the instance.\n        def fn(pixel_values, params):\n            return self.model.apply(\n                {'params': params},\n                pixel_values,\n                method=self.model.encode\n            )\n        return partial(self._wrap_fn(fn), params=self.params)\n\n    @cached_property\n    def _decode(self):\n        # Built on first access and cached on the instance.\n        def fn(encoding, params):\n            return self.model.apply(\n                {'params': params},\n                encoding,\n                method=self.model.decode\n            )\n        return partial(self._wrap_fn(fn), params=self.params)\n\n    def encode(self, pixel_values):\n        \"\"\"Return `(quantized_states, codebook_indices)` for `pixel_values`.\"\"\"\n        return self._encode(pixel_values)\n\n    def decode(self, encoding):\n        \"\"\"Decode codebook indices to pixel values clipped to [-1, 1].\"\"\"\n        return self._decode(encoding)\n\n\nclass VQGANConfig(PretrainedConfig):\n    model_type = \"vqgan\"\n\n    def __init__(\n        self,\n        resolution=256,\n        num_channels=3,\n        hidden_channels=128,\n        channel_mult=(1, 2, 2, 4, 6),\n        num_res_blocks=2,\n        attn_resolutions=(),\n        no_attn_mid_block=True,\n        z_channels=64,\n        num_embeddings=8192,\n        quantized_embed_dim=64,\n        dropout=0.0,\n        resample_with_conv=True,\n        commitment_cost=0.25\n    ):\n        self.resolution = resolution\n        self.num_channels = num_channels\n        self.hidden_channels = hidden_channels\n        self.channel_mult = channel_mult\n        self.num_res_blocks = num_res_blocks\n        self.attn_resolutions = attn_resolutions\n        self.no_attn_mid_block = no_attn_mid_block\n        self.z_channels = z_channels\n        self.num_embeddings = num_embeddings\n        self.quantized_embed_dim = quantized_embed_dim\n        self.dropout = dropout\n        self.resample_with_conv = resample_with_conv\n        self.commitment_cost = commitment_cost\n\n    @classmethod\n    def get_default_config(cls, updates=None):\n        \"\"\"Build a ConfigDict from the `__init__` signature, optionally applying `updates`.\n\n        Also derives `num_resolutions` from `channel_mult`.\n        \"\"\"\n        config = function_args_to_config(cls.__init__)\n        if updates is not None:\n            config.update(ConfigDict(updates).copy_and_resolve_references())\n        config.num_resolutions = len(config.channel_mult)\n        return config\n\n    @classmethod\n    def load_config(cls, path):\n        \"\"\"Return the default config; `path` is accepted for API compatibility but unused.\"\"\"\n        # Previously `cls.get_default_config(cls)`, which passed the class itself as\n        # `updates` -- not a mapping that `ConfigDict` can be built from.\n        return cls.get_default_config()\n\n\nclass VQGANModel(nn.Module):\n    \"\"\"Encoder -> quant_conv -> VectorQuantizer to encode; post_quant_conv -> Decoder to decode.\"\"\"\n    config: VQGANConfig\n\n    def setup(self):\n        self.encoder = Encoder(self.config)\n        self.decoder = Decoder(self.config)\n        self.quantize = VectorQuantizer(\n            self.config.num_embeddings, self.config.quantized_embed_dim\n        )\n        self.quant_conv = nn.Conv(self.config.quantized_embed_dim, [1, 1])\n        self.post_quant_conv = nn.Conv(self.config.z_channels, [1, 1])\n\n    def encode(self, pixel_values):\n        \"\"\"Return `(quantized_states, codebook_indices)`.\n\n        Accepts 4-D image batches or 5-D video batches; for 5-D input the second\n        axis is folded into the batch axis and restored on both outputs.\n        \"\"\"\n        T = None\n        if len(pixel_values.shape) == 5:  # video\n            T = pixel_values.shape[1]\n            pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])\n        hidden_states = self.encoder(pixel_values)\n        hidden_states = self.quant_conv(hidden_states)\n        quantized_states, codebook_indices = self.quantize(hidden_states)\n        if T is not None:\n            quantized_states = quantized_states.reshape(-1, T, *quantized_states.shape[1:])\n            codebook_indices = codebook_indices.reshape(-1, T, *codebook_indices.shape[1:])\n        return quantized_states, codebook_indices\n\n    def decode(self, encoding, is_codebook_indices=True):\n        \"\"\"Decode to pixel values clipped to [-1, 1].\n\n        `encoding` holds codebook indices by default, or quantized embeddings when\n        `is_codebook_indices` is False. 5-D (video) encodings have their second\n        axis folded into the batch axis and restored on the output.\n        \"\"\"\n        if is_codebook_indices:\n            encoding = self.quantize(None, encoding)\n        T = None\n        if len(encoding.shape) == 5:\n            T = encoding.shape[1]\n            encoding = encoding.reshape(-1, *encoding.shape[2:])\n        hidden_states = self.post_quant_conv(encoding)\n        reconstructed_pixel_values = self.decoder(hidden_states)\n        if T is not None:\n            reconstructed_pixel_values = reconstructed_pixel_values.reshape(-1, T, *reconstructed_pixel_values.shape[1:])\n        return jnp.clip(reconstructed_pixel_values, -1, 1)\n\n    def __call__(self, pixel_values):\n        \"\"\"Round-trip: encode to codebook indices, then decode.\"\"\"\n        encoding = self.encode(pixel_values)[1]\n        recon = self.decode(encoding)\n        return recon\n\n\nclass Encoder(nn.Module):\n    config: VQGANConfig\n\n    @nn.compact\n    def __call__(self, pixel_values):\n        # Axes 1 and 2 must both equal the configured resolution.\n        assert pixel_values.shape[1] == pixel_values.shape[2] == self.config.resolution, pixel_values.shape\n        hidden_states = nn.Conv(self.config.hidden_channels, [3, 3])(pixel_values)\n        for i_level in range(self.config.num_resolutions):\n            hidden_states = DownsamplingBlock(self.config, i_level)(hidden_states)\n        hidden_states = MidBlock(\n            self.config, self.config.no_attn_mid_block, self.config.dropout\n        )(hidden_states)\n        hidden_states = nn.GroupNorm()(hidden_states)\n        hidden_states = nn.silu(hidden_states)\n        hidden_states = nn.Conv(self.config.z_channels, [3, 3])(hidden_states)\n        return hidden_states\n\n\nclass Decoder(nn.Module):\n    config: VQGANConfig\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        hidden_states = nn.Conv(\n            self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1],\n            [3, 3]\n        )(hidden_states)\n        hidden_states = MidBlock(\n            self.config, self.config.no_attn_mid_block, self.config.dropout\n        )(hidden_states)\n        for i_level in reversed(range(self.config.num_resolutions)):\n            hidden_states = UpsamplingBlock(self.config, i_level)(hidden_states)\n        hidden_states = nn.GroupNorm()(hidden_states)\n        hidden_states = nn.silu(hidden_states)\n        hidden_states = nn.Conv(self.config.num_channels, [3, 3])(hidden_states)\n        return hidden_states\n\n\nclass VectorQuantizer(nn.Module):\n    n_e: int  # number of codebook entries\n    e_dim: int  # dimension of each codebook entry\n\n    @nn.compact\n    def __call__(self, z, encoding_indices=None):\n        \"\"\"Nearest-codebook vector quantization.\n\n        With `encoding_indices`, returns the codebook embeddings at those indices\n        (`z` is ignored). Otherwise returns `(z_q, indices)`, where `z_q` passes\n        gradients straight through to `z`.\n        \"\"\"\n        def quantize(encoding_indices):\n            w = jax.device_put(embeddings)\n            return w[(encoding_indices,)]\n        embeddings = self.param(\n            'embeddings',\n            lambda rng, shape, dtype: jax.random.uniform(\n                rng, shape, dtype, minval=-1.0 / self.n_e, maxval=1.0 / self.n_e\n            ),\n            [self.n_e, self.e_dim], jnp.float32\n        )\n\n        if encoding_indices is not None:\n            return quantize(encoding_indices)\n\n        z_flattened = z.reshape(-1, z.shape[-1])\n        # Squared L2 distance ||z||^2 + ||e||^2 - 2 z.e to every codebook entry.\n        d = (\n            jnp.sum(z_flattened ** 2, axis=1, keepdims=True)\n            + jnp.sum(embeddings.T ** 2, axis=0, keepdims=True)\n            - 2 * jnp.einsum('bd,nd->bn', z_flattened, embeddings)\n        )\n\n        min_encoding_indices = jnp.argmin(d, axis=1)\n        z_q = quantize(min_encoding_indices)\n        z_q = jnp.reshape(z_q, z.shape)\n        # Straight-through estimator: forward value is z_q, gradient flows to z.\n        z_q = z + jax.lax.stop_gradient(z_q - z)\n\n        min_encoding_indices = jnp.reshape(min_encoding_indices, z.shape[:-1])\n\n        return z_q, min_encoding_indices\n\n\nclass DownsamplingBlock(nn.Module):\n    config: VQGANConfig\n    block_idx: int\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]\n        for _ in range(self.config.num_res_blocks):\n            hidden_states = ResnetBlock(\n                block_out, dropout_prob=self.config.dropout\n            )(hidden_states)\n            if hidden_states.shape[1] in self.config.attn_resolutions:\n                hidden_states = AttnBlock()(hidden_states)\n        # The last level keeps its spatial size.\n        if self.block_idx != self.config.num_resolutions - 1:\n            hidden_states = Downsample(self.config.resample_with_conv)(hidden_states)\n        return hidden_states\n\n\nclass ResnetBlock(nn.Module):\n    out_channels: Optional[int] = None\n    use_conv_shortcut: bool = False\n    dropout_prob: float = 0.0\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        out_channels = self.out_channels or hidden_states.shape[-1]\n        residual = hidden_states\n        hidden_states = nn.GroupNorm()(hidden_states)\n        hidden_states = nn.silu(hidden_states)\n        hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)\n        hidden_states = nn.GroupNorm()(hidden_states)\n        hidden_states = nn.silu(hidden_states)\n        hidden_states = nn.Dropout(self.dropout_prob, deterministic=True)(hidden_states)\n        hidden_states = nn.Conv(out_channels, [3, 3])(hidden_states)\n        # Project the shortcut only when the channel count changes.\n        if out_channels != residual.shape[-1]:\n            if self.use_conv_shortcut:\n                residual = nn.Conv(out_channels, [3, 3])(residual)\n            else:\n                residual = nn.Conv(out_channels, [1, 1])(residual)\n        return hidden_states + residual\n\n\nclass AttnBlock(nn.Module):\n    \"\"\"Single-head self-attention over all spatial positions, with a residual connection.\"\"\"\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        residual = hidden_states\n        hidden_states = nn.GroupNorm()(hidden_states)\n        query = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)\n        key = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)\n        value = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)\n        # Flatten the spatial axes so every position attends to every other one.\n        query, key, value = map(\n            lambda x: x.reshape(x.shape[0], -1, x.shape[-1]),\n            [query, key, value]\n        )\n        attn_weights = jnp.einsum(\"bqd,bkd->bqk\", query, key)\n        attn_weights *= hidden_states.shape[-1] ** -0.5\n        attn_weights = jax.nn.softmax(attn_weights, axis=-1)\n        hidden_states = jnp.einsum(\"bqk,bkd->bqd\", attn_weights, value)\n        # Restore the spatial layout; without this the residual add below cannot\n        # broadcast the flattened (B, H*W, C) output against (B, H, W, C).\n        hidden_states = hidden_states.reshape(residual.shape)\n        hidden_states = nn.Conv(hidden_states.shape[-1], [1, 1])(hidden_states)\n        return hidden_states + residual\n\n\nclass Downsample(nn.Module):\n    with_conv: bool\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        if self.with_conv:\n            # Pad bottom/right by one so the stride-2 VALID conv halves the size.\n            hidden_states = jnp.pad(\n                hidden_states,\n                [(0, 0), (0, 1), (0, 1), (0, 0)]\n            )\n            hidden_states = nn.Conv(\n                hidden_states.shape[-1], [3, 3],\n                strides=[2, 2],\n                padding=\"VALID\"\n            )(hidden_states)\n        else:\n            hidden_states = nn.avg_pool(hidden_states, [2, 2], [2, 2])\n        return hidden_states\n\n\nclass Upsample(nn.Module):\n    with_conv: bool\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        B, H, W, C = hidden_states.shape\n        hidden_states = jax.image.resize(\n            hidden_states,\n            (B, H * 2, W * 2, C),\n            method=\"nearest\"\n        )\n        if self.with_conv:\n            hidden_states = nn.Conv(hidden_states.shape[-1], [3, 3])(hidden_states)\n        return hidden_states\n\n\nclass UpsamplingBlock(nn.Module):\n    config: VQGANConfig\n    block_idx: int\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]\n        for _ in range(self.config.num_res_blocks + 1):\n            hidden_states = ResnetBlock(\n                block_out, dropout_prob=self.config.dropout\n            )(hidden_states)\n            if hidden_states.shape[1] in self.config.attn_resolutions:\n                hidden_states = AttnBlock()(hidden_states)\n        # Level 0 is the output resolution, so it is not upsampled further.\n        if self.block_idx != 0:\n            hidden_states = Upsample(self.config.resample_with_conv)(hidden_states)\n        return hidden_states\n\n\nclass MidBlock(nn.Module):\n    config: VQGANConfig\n    no_attn: bool\n    dropout: float\n\n    @nn.compact\n    def __call__(self, hidden_states):\n        hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)\n        if not self.no_attn:\n            hidden_states = AttnBlock()(hidden_states)\n        hidden_states = ResnetBlock(dropout_prob=self.dropout)(hidden_states)\n        return hidden_states\n"
  },
  {
    "path": "scripts/create_needle_data.py",
    "content": "import os\nimport argparse\nimport json\nfrom tqdm import tqdm\nfrom datasets import load_dataset\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--output_path\", type=str, default=\"data/pg19.jsonl\")\nargs = parser.parse_args()\n\nos.makedirs(os.path.dirname(args.output_path), exist_ok=True)\n\ndset = load_dataset(\"pg19\")[\"train\"]\nwith open(args.output_path, \"w\") as f:\n    for elem in tqdm(dset):\n        data = {\"text\": elem[\"text\"]}\n        f.write(f\"{json.dumps(data)}\\n\")"
  },
  {
    "path": "scripts/eval_needle.py",
    "content": "from absl.app import run\nimport time\nimport json\nimport math\nimport os\nfrom tqdm import tqdm\nimport random\nfrom functools import cached_property\nimport numpy as np\nimport jax\nfrom jax.experimental.pjit import pjit\nfrom jax.sharding import PartitionSpec as PS\nimport gcsfs\nimport tiktoken\nfrom transformers import GenerationConfig, AutoTokenizer\nfrom tux import (\n    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,\n    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,\n    match_partition_rules, make_shard_and_gather_fns,\n    with_sharding_constraint, tree_apply, open_file\n)\nfrom lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM\n\n\nFLAGS, FLAGS_DEF = define_flags_with_default(\n    haystack_file=\"\",\n    max_tokens_per_batch=2000000,\n    output_file=\"results.json\",\n    context_lengths_min=1000,\n    context_lengths_max=32000,\n    n_context_length_intervals=3,\n    n_document_depth_intervals=3,\n    n_rounds=2,\n    seed=1234,\n    mesh_dim='1,-1,1,1',\n    dtype='fp32',\n    load_llama_config='',\n    update_llama_config='',\n    load_checkpoint='',\n    tokenizer='LargeWorldModel/LWM-Text-1M',\n    checkpointer=StreamingCheckpointer.get_default_config(),\n    llama=LLaMAConfig.get_default_config(),\n    jax_distributed=JaxDistributedConfig.get_default_config(),\n)\n\n\nclass LLMNeedleHaystackTester:\n    OURS_TEMPLATE = \"You are a helpful assistant. USER: {context} {question} Don't give information outside the document or repeat your findings. Keep your response short and direct. 
ASSISTANT: \"\n    RANDOM_NEEDLE_CITIES  = [\n        'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',\n        'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',\n        'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',\n        'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',\n        'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',\n        'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',\n        'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',\n        'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',\n        'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',\n        'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',\n        'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',\n        'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'\n    ]\n\n    def __init__(self,\n                 needle=\"\",\n                 haystack_file=\"\",\n                 retrieval_question=\"What is the special magic {} number?\",\n                 results_version = 1,\n                 rnd_number_digits = 7,\n                 context_lengths_min = 1000,\n                 context_lengths_max = 126000,\n                 context_lengths_num_intervals = 10,\n                 document_depth_percent_min = 0,\n                 document_depth_percent_max = 100,\n                 document_depth_percent_intervals = 10,\n                 document_depth_percent_interval_type = \"linear\",\n                 save_results = False,\n                 
final_context_length_buffer = 200,\n                 print_ongoing_status = True):\n        needle=\"\\nThe special magic {city} number is: {rnd_number}\\n\"\n        self.needle = needle\n        if not needle or not haystack_file or not retrieval_question:\n            raise ValueError(\"Needle, haystack, and retrieval_question must be provided.\")\n\n        self.rnd_number_digits = rnd_number_digits\n        self.context_lengths_num_intervals = context_lengths_num_intervals\n        self.document_depth_percent_intervals = document_depth_percent_intervals\n        self.haystack_file = haystack_file\n        self.retrieval_question = retrieval_question\n        self.results_version = results_version\n        self.save_results = save_results\n        self.final_context_length_buffer = final_context_length_buffer\n        self.print_ongoing_status = print_ongoing_status\n        self.testing_results = []\n\n        self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)\n        if document_depth_percent_interval_type == 'linear':\n            self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)\n        elif document_depth_percent_interval_type == 'sigmoid':\n            self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]\n        else:\n            raise ValueError(f\"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}\")\n\n        self.model = Sampler()\n\n        self.enc = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n        self.enc_tiktoken = tiktoken.encoding_for_model(\"gpt-4-1106-preview\")\n\n    def generate_random_number(self, num_digits):\n        lower_bound = 10**(num_digits - 1)\n        
upper_bound = 10**num_digits - 1\n        return random.randint(lower_bound, upper_bound)\n\n    def logistic(self, x, L=100, x0=50, k=.1):\n        if x == 0:\n            return 0\n        if x == 100:\n            return 100\n        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)\n\n    def read_context_files(self, n):\n        max_context_length = max(self.context_lengths)\n        contexts = []\n        f = open_file(self.haystack_file, 'r')\n        for _ in range(n):\n            context = \"\"\n            toks = 0\n            while toks < max_context_length:\n                text = json.loads(f.readline())['text']\n                context += text\n                toks += len(self.enc.encode(text))\n            contexts.append(context)\n        return contexts\n\n    def encode_and_trim(self, context, context_length):\n        tokens = self.enc.encode(context)\n        if len(tokens) > context_length:\n            context = self.enc.decode(tokens[:context_length])\n        return context\n\n    def create_contexts(self, needle_rnd_number, insert_needle, random_city, trim_context, context_length, depth_percent, seed):\n        if self.save_results:\n            if self.result_exists(context_length, depth_percent):\n                return\n        needle = self.needle.format(city=random_city, rnd_number=needle_rnd_number)\n        question = self.retrieval_question.format(random_city)\n        if not insert_needle:\n            needle = \" \" #replace needle with a space\n        context = self.generate_context(needle, trim_context, context_length, depth_percent)\n        results = {\n            'context' : context,\n            'context_length' : int(context_length),\n            'depth_percent' : float(depth_percent),\n            'needle' : needle,\n            'question' : question,\n            'insert_needle' : insert_needle,\n            'needle_rnd_number' : needle_rnd_number,\n            'seed': seed,\n         }\n        return results\n\n   
 def insert_needle(self, needle, context, depth_percent, context_length):\n        tokens_needle = self.enc_tiktoken.encode(needle)\n        tokens_context = self.enc_tiktoken.encode(context)\n\n        # Reducing the context length by 150 buffer. This is to account for system message, the user question, and response.\n        context_length -= self.final_context_length_buffer\n\n        # If your context + needle are longer than the context length (which it will be), then reduce tokens from the context by the needle length\n        if len(tokens_context) + len(tokens_needle) > context_length:\n            tokens_context = tokens_context[:context_length - len(tokens_needle)]\n\n        if depth_percent == 100:\n            # If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end\n            tokens_new_context = tokens_context + tokens_needle\n        else:\n            # Go get the position (in terms of tokens) to insert your needle\n            insertion_point = int(len(tokens_context) * (depth_percent / 100))\n\n            # tokens_new_context represents the tokens before the needle\n            tokens_new_context = tokens_context[:insertion_point]\n\n            # We want to make sure that we place our needle at a sentence break so we first see what token a '.' 
is\n            period_tokens = self.enc_tiktoken.encode('.')\n\n            # Then we iteration backwards until we find the first period\n            while tokens_new_context and tokens_new_context[-1] not in period_tokens:\n                insertion_point -= 1\n                tokens_new_context = tokens_context[:insertion_point]\n\n            # Once we get there, then add in your needle, and stick the rest of your context in on the other end.\n            # Now we have a needle in a haystack\n            tokens_new_context += tokens_needle + tokens_context[insertion_point:]\n\n        # Convert back to a string and return it\n        new_context = self.enc_tiktoken.decode(tokens_new_context)\n        return new_context\n\n    def generate_context(self, needle, trim_context, context_length, depth_percent):\n        context = self.insert_needle(needle, trim_context, depth_percent, context_length)\n        return context\n\n    def compute_max_input_length(self, context_length, buffer=1024):\n        block_size = self.model.block_size\n        context_length += buffer\n        context_length = math.ceil(context_length / block_size) * block_size\n        return int(context_length)\n\n    def run_test(self):\n        fs = gcsfs.GCSFileSystem()\n        contexts = []\n        template = self.OURS_TEMPLATE\n\n        def _key_from_result(result):\n            return (result['context_length'], result['depth_percent'], result['seed'])\n\n        results = []\n        completed = set()\n        def exists(fname):\n            if fname.startswith('gs://'):\n                return fs.exists(fname)\n            else:\n                return os.path.exists(fname)\n        if exists(FLAGS.output_file):\n            with open_file(FLAGS.output_file, 'r') as f:\n                results = json.load(f)\n                completed = set([_key_from_result(result) for result in results])\n        print('completed', len(completed))\n\n        full_contexts = 
self.read_context_files(FLAGS.n_rounds)\n        full_tokens = [self.enc.encode(full_context) for full_context in tqdm(full_contexts)]\n\n        start = time.time()\n        for context_length in self.context_lengths:\n            trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in tqdm(full_tokens)]\n            max_input_length = self.compute_max_input_length(context_length)\n            contexts = []\n            for depth_percent in self.document_depth_percents:\n                for i in range(FLAGS.n_rounds):\n                    if (int(context_length), float(depth_percent), i) in completed:\n                        continue\n                    random_city = random.choice(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES)\n                    insert_needle = True\n                    needle_rnd_number = str(self.generate_random_number(self.rnd_number_digits))\n                    print(\"context length: \" + str(context_length))\n                    print(\"depth_percent : \" + str(depth_percent))\n                    context = self.create_contexts(needle_rnd_number, insert_needle, random_city, trim_contexts[i], context_length, depth_percent, i)\n                    contexts.append(context)\n\n            if len(contexts) == 0:\n                continue\n\n            B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)\n            B = int(B / self.model.data_dim) * self.model.data_dim\n            if B < self.model.data_dim:\n                B = self.model.data_dim\n            elif B > len(contexts):\n                B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)\n            if len(contexts) % B == 0:\n                n_pad = 0\n            else:\n                n_pad = B - len(contexts) % B\n            for _ in range(n_pad):\n                contexts.insert(0, contexts[0])\n\n            pbar = tqdm(total=len(contexts))\n            for i in range(0, len(contexts), B):\n     
           contexts_i = contexts[i:i + B]\n                prompts = [\n                    template.format(context=context['context'], question=context['question'])\n                    for context in contexts_i\n                ]\n                outs = self.model(prompts, max_input_length)\n                for j, (context, out) in enumerate(zip(contexts_i, outs)):\n                    if i + j < n_pad:\n                        continue\n                    results.append({\n                        'context_length': context['context_length'],\n                        'depth_percent': context['depth_percent'],\n                        'response': out,\n                        'answer': context['needle_rnd_number'],\n                        'correct': context['needle_rnd_number'] in out,\n                        'seed': context['seed'],\n                    })\n                    print(results[-1])\n                if jax.process_index() == 0:\n                    with open_file(FLAGS.output_file, 'w') as f:\n                        json.dump(results, f)\n                pbar.update(len(contexts_i))\n            pbar.close()\n        print('elapsed', time.time() - start)\n        print('done')\n\n\n    def print_start_test_summary(self):\n        print (\"\\n\")\n        print (\"Starting Needle In A Haystack Testing...\")\n        print (f\"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}\")\n        print (f\"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%\")\n        print (f\"- Needle: {self.needle.strip()}\")\n        print (\"\\n\\n\")\n\n    def start_test(self):\n        if self.print_ongoing_status:\n            self.print_start_test_summary()\n        self.run_test()\n\n\n\nclass Sampler:\n    def __init__(self):\n        self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)\n        
self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')\n        self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n        self.sharded_rng = next_rng()\n        self._load_model()\n\n    @property\n    def block_size(self):\n        # return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size)\n        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']\n\n    @property\n    def data_dim(self):\n        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']\n\n    def _load_model(self):\n        if FLAGS.load_llama_config != '':\n            llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)\n            updates = LLaMAConfig(**FLAGS.llama)\n            llama_config.update(dict(\n                scan_attention=updates.scan_attention,\n                scan_mlp=updates.scan_mlp,\n                scan_query_chunk_size=updates.scan_query_chunk_size,\n                scan_key_chunk_size=updates.scan_key_chunk_size,\n                scan_mlp_chunk_size=updates.scan_mlp_chunk_size,\n                scan_layers=updates.scan_layers,\n                param_scan_axis=updates.param_scan_axis,\n            ))\n        else:\n            llama_config = LLaMAConfig(**FLAGS.llama)\n\n        if FLAGS.update_llama_config != '':\n            llama_config.update(dict(eval(FLAGS.update_llama_config)))\n\n        llama_config.update(dict(\n            bos_token_id=self.tokenizer.bos_token_id,\n            eos_token_id=self.tokenizer.eos_token_id,\n        ))\n        llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))\n        self.config = llama_config\n\n        with jax.default_device(jax.devices(\"cpu\")[0]):\n            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(\n                    FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30\n            )\n            self.model 
= FlaxLLaMAForCausalLM(\n                llama_config,\n                input_shape=(512, self.block_size),\n                seed=FLAGS.seed,\n                _do_init=False,\n                dtype=get_float_dtype_by_name(FLAGS.dtype),\n            )\n            self.model_ps = match_partition_rules(\n                LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params\n            )\n            shard_fns, _ = make_shard_and_gather_fns(\n                self.model_ps, get_float_dtype_by_name(FLAGS.dtype)\n            )\n\n            with self.mesh:\n                self.params = tree_apply(shard_fns, self.params)\n\n    @cached_property\n    def _forward_generate(self):\n        def fn(params, rng, batch):\n            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))\n            rng_generator = JaxRNG(rng)\n            output = self.model.generate(\n                batch['input_ids'],\n                attention_mask=batch['attention_mask'],\n                params=params['params'],\n                prng_key=rng_generator(),\n                generation_config=GenerationConfig(\n                    max_new_tokens=self.block_size,\n                    pad_token_id=self.tokenizer.pad_token_id,\n                    eos_token_id=self.tokenizer.eos_token_id,\n                    temperature=0.,\n                    do_sample=False,\n                    num_beams=1,\n                    top_k=50,\n                    top_p=1.0,\n                )\n            ).sequences[:, batch['input_ids'].shape[1]:]\n            return output, rng_generator()\n        return pjit(\n            fn,\n            in_shardings=(self.model_ps, PS(), PS()),\n            out_shardings=(PS(), PS())\n        )\n\n    def __call__(self, prompts, max_input_length):\n        inputs = self.prefix_tokenizer(\n            prompts,\n            padding='max_length',\n            truncation=True,\n            
max_length=max_input_length,\n            return_tensors='np'\n        )\n        batch = dict(\n            input_ids=inputs.input_ids,\n            attention_mask=inputs.attention_mask\n        )\n        with self.mesh:\n            output, self.sharded_rng = self._forward_generate(\n                self.params, self.sharded_rng, batch\n            )\n            output = jax.device_get(output)\n        output_text = []\n        for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):\n            if self.tokenizer.eos_token in text:\n                text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]\n            output_text.append(text)\n        return output_text\n\n\ndef main(argv):\n    JaxDistributedConfig.initialize(FLAGS.jax_distributed)\n    set_random_seed(FLAGS.seed)\n\n    ht = LLMNeedleHaystackTester(\n        haystack_file=FLAGS.haystack_file,\n        context_lengths_min=FLAGS.context_lengths_min,\n        context_lengths_max=FLAGS.context_lengths_max,\n        context_lengths_num_intervals=FLAGS.n_context_length_intervals,\n        document_depth_percent_intervals=FLAGS.n_document_depth_intervals,\n    )\n    ht.start_test()\n\nif __name__ == \"__main__\":\n    run(main)\n"
  },
  {
    "path": "scripts/eval_needle_multi.py",
    "content": "from absl.app import run\nimport glob\nimport time\nimport json\nimport math\nimport os\nfrom tqdm import tqdm\nimport random\nfrom functools import cached_property\nimport numpy as np\nimport jax\nfrom jax.experimental.pjit import pjit\nfrom jax.sharding import PartitionSpec as PS\nimport gcsfs\nimport tiktoken\nfrom transformers import GenerationConfig, AutoTokenizer\nfrom tux import (\n    define_flags_with_default, StreamingCheckpointer, JaxDistributedConfig,\n    set_random_seed, get_float_dtype_by_name, JaxRNG, next_rng,\n    match_partition_rules, make_shard_and_gather_fns,\n    with_sharding_constraint, tree_apply, open_file\n)\nfrom lwm.llama import LLaMAConfig, FlaxLLaMAForCausalLM\n\n\nFLAGS, FLAGS_DEF = define_flags_with_default(\n    haystack_file=\"\",\n    max_tokens_per_batch=2000000,\n    output_file=\"results.json\",\n    context_lengths_min=1000,\n    context_lengths_max=32000,\n    n_context_length_intervals=3,\n    n_document_depth_intervals=3,\n    n_rounds=2,\n    n_needles_total=4,\n    n_needles_retrieve=4,\n    seed=1234,\n    mesh_dim='1,-1,1,1',\n    dtype='fp32',\n    load_llama_config='',\n    update_llama_config='',\n    load_checkpoint='',\n    tokenizer='LargeWorldModel/LWM-Text-1M',\n    checkpointer=StreamingCheckpointer.get_default_config(),\n    llama=LLaMAConfig.get_default_config(),\n    jax_distributed=JaxDistributedConfig.get_default_config(),\n)\n\n\nclass LLMNeedleHaystackTester:\n    OURS_TEMPLATE = \"You are a helpful assistant. USER: {context} {question} Don't give information outside the document. 
ASSISTANT: \"\n    RANDOM_NEEDLE_CITIES  = [\n        'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',\n        'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',\n        'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',\n        'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',\n        'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',\n        'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',\n        'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',\n        'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',\n        'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',\n        'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',\n        'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',\n        'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta'\n    ]\n\n    def __init__(self,\n                 needle=\"\",\n                 haystack_file=\"\",\n                 retrieval_question=\"What are the special magic numbers for {}?\",\n                 results_version = 1,\n                 rnd_number_digits = 7,\n                 context_lengths_min = 1000,\n                 context_lengths_max = 126000,\n                 context_lengths_num_intervals = 10,\n                 document_depth_percent_min = 0,\n                 document_depth_percent_max = 100,\n                 document_depth_percent_intervals = 10,\n                 document_depth_percent_interval_type = \"linear\",\n                 save_results = False,\n            
     final_context_length_buffer = 200,\n                 print_ongoing_status = True):\n        needle=\"\\nThe special magic {city} number is: {rnd_number}\\n\"\n        self.needle = needle\n        if not needle or not haystack_file or not retrieval_question:\n            raise ValueError(\"Needle, haystack, and retrieval_question must be provided.\")\n\n        self.rnd_number_digits = rnd_number_digits\n        self.context_lengths_num_intervals = context_lengths_num_intervals\n        self.document_depth_percent_intervals = document_depth_percent_intervals\n        self.haystack_file = haystack_file\n        self.retrieval_question = retrieval_question\n        self.results_version = results_version\n        self.save_results = save_results\n        self.final_context_length_buffer = final_context_length_buffer\n        self.print_ongoing_status = print_ongoing_status\n        self.testing_results = []\n\n        self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)\n        self.context_lengths = self.context_lengths.tolist()\n        if document_depth_percent_interval_type == 'linear':\n            self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)\n        elif document_depth_percent_interval_type == 'sigmoid':\n            self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]\n        else:\n            raise ValueError(f\"Unsupported document_depth_percent_interval_type: {document_depth_percent_interval_type}\")\n        self.document_depth_percents = self.document_depth_percents.tolist()\n\n        self.model = Sampler()\n\n        self.enc = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n        self.enc_tiktoken = 
tiktoken.encoding_for_model(\"gpt-4-1106-preview\")\n\n    def generate_random_number(self, num_digits):\n        lower_bound = 10**(num_digits - 1)\n        upper_bound = 10**num_digits - 1\n        return random.randint(lower_bound, upper_bound)\n\n    def logistic(self, x, L=100, x0=50, k=.1):\n        if x == 0:\n            return 0\n        if x == 100:\n            return 100\n        return np.round(L / (1 + np.exp(-k * (x - x0))), 3)\n\n    def read_context_files(self, n):\n        max_context_length = max(self.context_lengths)\n        contexts = []\n        f = open_file(self.haystack_file, 'r')\n        for i in range(n):\n            context = \"\"\n            while len(self.enc.encode(context)) < max_context_length:\n                context += json.loads(f.readline())['text']\n            contexts.append(context)\n        return contexts\n\n    def encode_and_trim(self, context, context_length):\n        tokens = self.enc.encode(context)\n        if len(tokens) > context_length:\n            context = self.enc.decode(tokens[:context_length])\n        return context\n\n    def create_contexts(self, needles_info, random_cities_retrieve, context, context_length, seed):\n        assert all([random_city in needles_info for random_city in random_cities_retrieve])\n        for random_city, (needle_rnd_number, depth_percent) in needles_info.items():\n            context = self.generate_context(\n                self.needle.format(city=random_city, rnd_number=needle_rnd_number),\n                context, context_length, depth_percent\n            )\n\n        if len(random_cities_retrieve) == 1:\n            question = f\"What is the special magic number for {random_cities_retrieve[0]}?\"\n        else:\n            q = ', '.join(random_cities_retrieve[:-1]) + ', and ' + random_cities_retrieve[-1]\n            question = self.retrieval_question.format(q)\n        results = {\n            'context' : context,\n            'context_length' : 
int(context_length),\n            'needles_info': needles_info,\n            'question' : question,\n            'cities_to_retrieve' : random_cities_retrieve,\n            'seed': seed,\n         }\n        return results\n\n    def insert_needle(self, needle, context, depth_percent, context_length):\n        tokens_needle = self.enc_tiktoken.encode(needle)\n        tokens_context = self.enc_tiktoken.encode(context)\n\n        # Reducing the context length by 150 buffer. This is to account for system message, the user question, and response.\n        context_length -= self.final_context_length_buffer\n\n        # If your context + needle are longer than the context length (which it will be), then reduce tokens from the context by the needle length\n        if len(tokens_context) + len(tokens_needle) > context_length:\n            tokens_context = tokens_context[:context_length - len(tokens_needle)]\n\n        if depth_percent == 100:\n            # If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end\n            tokens_new_context = tokens_context + tokens_needle\n        else:\n            # Go get the position (in terms of tokens) to insert your needle\n            insertion_point = int(len(tokens_context) * (depth_percent / 100))\n\n            # tokens_new_context represents the tokens before the needle\n            tokens_new_context = tokens_context[:insertion_point]\n\n            # We want to make sure that we place our needle at a sentence break so we first see what token a '.' 
is\n            period_tokens = self.enc_tiktoken.encode('.')\n\n            # Then we iteration backwards until we find the first period\n            while tokens_new_context and tokens_new_context[-1] not in period_tokens:\n                insertion_point -= 1\n                tokens_new_context = tokens_context[:insertion_point]\n\n            # Once we get there, then add in your needle, and stick the rest of your context in on the other end.\n            # Now we have a needle in a haystack\n            tokens_new_context += tokens_needle + tokens_context[insertion_point:]\n\n        # Convert back to a string and return it\n        new_context = self.enc_tiktoken.decode(tokens_new_context)\n        return new_context\n\n    def generate_context(self, needle, trim_context, context_length, depth_percent):\n        context = self.insert_needle(needle, trim_context, depth_percent, context_length)\n        return context\n\n    def compute_max_input_length(self, context_length, buffer=1024):\n        block_size = self.model.block_size\n        context_length += buffer\n        # context_length = 2 ** math.ceil(math.log2(context_length))\n        context_length = math.ceil(context_length / block_size) * block_size\n        return int(context_length)\n\n    def run_test(self):\n        fs = gcsfs.GCSFileSystem()\n        contexts = []\n        template = self.OURS_TEMPLATE\n\n        def _key_from_result(result):\n            return (result['context_length'], result['depth_percent'], result['seed'])\n\n        results = []\n        completed = set()\n        def exists(fname):\n            if fname.startswith('gs://'):\n                return fs.exists(fname)\n            else:\n                return os.path.exists(fname)\n        if exists(FLAGS.output_file):\n            with open_file(FLAGS.output_file, 'r') as f:\n                results = json.load(f)\n                completed = set([_key_from_result(result) for result in results])\n        print('completed', 
len(completed))\n\n        full_contexts = self.read_context_files(FLAGS.n_rounds)\n        full_tokens = [self.enc.encode(full_context) for full_context in full_contexts]\n\n        start = time.time()\n        for context_length in self.context_lengths:\n            trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in full_tokens]\n            max_input_length = self.compute_max_input_length(context_length)\n            contexts = []\n            for i in range(FLAGS.n_rounds):\n                if (int(context_length), i) in completed:\n                    continue\n                random_cities = random.sample(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES, FLAGS.n_needles_total)\n                document_depths = random.sample(self.document_depth_percents, FLAGS.n_needles_total)\n                random_cities_retrieve = random.sample(random_cities, FLAGS.n_needles_retrieve)\n                needles_info = {}\n                for random_city, depth_percent in zip(random_cities, document_depths):\n                    needles_info[random_city] = (\n                        str(self.generate_random_number(self.rnd_number_digits)),\n                        depth_percent\n                    )\n                context = self.create_contexts(needles_info, random_cities_retrieve, trim_contexts[i], context_length, i)\n                contexts.append(context)\n\n            if len(contexts) == 0:\n                continue\n\n            B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)\n            B = int(B / self.model.data_dim) * self.model.data_dim\n            if B < self.model.data_dim:\n                B = self.model.data_dim\n            elif B > len(contexts):\n                B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)\n            n_pad = B - len(contexts) % B\n            for _ in range(n_pad):\n                contexts.insert(0, contexts[0])\n\n            pbar = 
tqdm(total=len(contexts))\n            for i in range(0, len(contexts), B):\n                contexts_i = contexts[i:i + B]\n                prompts = [\n                    template.format(context=context['context'], question=context['question'])\n                    for context in contexts_i\n                ]\n                outs = self.model(prompts, max_input_length)\n                for j, (context, out) in enumerate(zip(contexts_i, outs)):\n                    if i + j < n_pad:\n                        continue\n                    rnd_nums_to_retrieve = [\n                        context['needles_info'][city][0] for city in context['cities_to_retrieve']\n                    ]\n                    results.append({\n                        'context_length': context['context_length'],\n                        'needles_info': context['needles_info'],\n                        'question': context['question'],\n                        'answer': rnd_nums_to_retrieve,\n                        'response': out,\n                        'correct': [rnd_num in out for rnd_num in rnd_nums_to_retrieve],\n                        'seed': context['seed'],\n                    })\n                    print(results[-1]['correct'], out, rnd_nums_to_retrieve)\n                if jax.process_index() == 0:\n                    with open_file(FLAGS.output_file, 'w') as f:\n                        json.dump(results, f)\n                pbar.update(len(contexts_i))\n            pbar.close()\n        print('elapsed', time.time() - start)\n        print('done')\n\n\n    def print_start_test_summary(self):\n        print (\"\\n\")\n        print (\"Starting Needle In A Haystack Testing...\")\n        print (f\"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}\")\n        print (f\"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: 
{max(self.document_depth_percents)}%\")\n        print (f\"- Needle: {self.needle.strip()}\")\n        print (\"\\n\\n\")\n\n    def start_test(self):\n        if self.print_ongoing_status:\n            self.print_start_test_summary()\n        self.run_test()\n\n\n\nclass Sampler:\n    def __init__(self):\n        self.mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)\n        self.prefix_tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer, truncation_side='left', padding_side='left')\n        self.tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer)\n        self.sharded_rng = next_rng()\n        self._load_model()\n\n    @property\n    def block_size(self):\n        # return 2 * max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size)\n        return max(self.config.scan_query_chunk_size, self.config.scan_key_chunk_size) * self.mesh.shape['sp']\n\n    @property\n    def data_dim(self):\n        return self.mesh.shape['dp'] * self.mesh.shape['fsdp']\n\n    def _load_model(self):\n        if FLAGS.load_llama_config != '':\n            llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)\n            updates = LLaMAConfig(**FLAGS.llama)\n            llama_config.update(dict(\n                scan_attention=updates.scan_attention,\n                scan_mlp=updates.scan_mlp,\n                scan_query_chunk_size=updates.scan_query_chunk_size,\n                scan_key_chunk_size=updates.scan_key_chunk_size,\n                scan_mlp_chunk_size=updates.scan_mlp_chunk_size,\n                scan_layers=updates.scan_layers,\n                param_scan_axis=updates.param_scan_axis,\n            ))\n        else:\n            llama_config = LLaMAConfig(**FLAGS.llama)\n\n        if FLAGS.update_llama_config != '':\n            llama_config.update(dict(eval(FLAGS.update_llama_config)))\n\n        llama_config.update(dict(\n            bos_token_id=self.tokenizer.bos_token_id,\n            eos_token_id=self.tokenizer.eos_token_id,\n        
))\n        llama_config.update(dict(mesh_dim=FLAGS.mesh_dim))\n        self.config = llama_config\n\n        with jax.default_device(jax.devices(\"cpu\")[0]):\n            _, self.params = StreamingCheckpointer.load_trainstate_checkpoint(\n                    FLAGS.load_checkpoint, disallow_trainstate=True, max_buffer_size=32 * 2 ** 30\n            )\n            self.model = FlaxLLaMAForCausalLM(\n                llama_config,\n                input_shape=(512, self.block_size),\n                seed=FLAGS.seed,\n                _do_init=False,\n                dtype=get_float_dtype_by_name(FLAGS.dtype),\n            )\n            self.model_ps = match_partition_rules(\n                LLaMAConfig.get_partition_rules(llama_config.scan_layers, llama_config.param_scan_axis), self.params\n            )\n            shard_fns, _ = make_shard_and_gather_fns(\n                self.model_ps, get_float_dtype_by_name(FLAGS.dtype)\n            )\n\n            with self.mesh:\n                self.params = tree_apply(shard_fns, self.params)\n\n    @cached_property\n    def _forward_generate(self):\n        def fn(params, rng, batch):\n            batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'), 'sp'))\n            rng_generator = JaxRNG(rng)\n            output = self.model.generate(\n                batch['input_ids'],\n                attention_mask=batch['attention_mask'],\n                params=params['params'],\n                prng_key=rng_generator(),\n                generation_config=GenerationConfig(\n                    max_new_tokens=self.block_size,\n                    pad_token_id=self.tokenizer.pad_token_id,\n                    eos_token_id=self.tokenizer.eos_token_id,\n                    temperature=0.,\n                    do_sample=False,\n                    num_beams=1,\n                    top_k=50,\n                    top_p=1.0,\n                )\n            ).sequences[:, batch['input_ids'].shape[1]:]\n            return output, 
rng_generator()\n        return pjit(\n            fn,\n            in_shardings=(self.model_ps, PS(), PS()),\n            out_shardings=(PS(), PS())\n        )\n\n    def __call__(self, prompts, max_input_length):\n        inputs = self.prefix_tokenizer(\n            prompts,\n            padding='max_length',\n            truncation=True,\n            max_length=max_input_length,\n            return_tensors='np'\n        )\n        batch = dict(\n            input_ids=inputs.input_ids,\n            attention_mask=inputs.attention_mask\n        )\n        with self.mesh:\n            output, self.sharded_rng = self._forward_generate(\n                self.params, self.sharded_rng, batch\n            )\n            output = jax.device_get(output)\n        output_text = []\n        for text in list(self.tokenizer.batch_decode(output, skip_special_tokens=True)):\n            if self.tokenizer.eos_token in text:\n                text = text.split(self.tokenizer.eos_token, maxsplit=1)[0]\n            output_text.append(text)\n        return output_text\n\n\ndef main(argv):\n    JaxDistributedConfig.initialize(FLAGS.jax_distributed)\n    set_random_seed(FLAGS.seed)\n\n    ht = LLMNeedleHaystackTester(\n        haystack_file=FLAGS.haystack_file,\n        context_lengths_min=FLAGS.context_lengths_min,\n        context_lengths_max=FLAGS.context_lengths_max,\n        context_lengths_num_intervals=FLAGS.n_context_length_intervals,\n        document_depth_percent_intervals=FLAGS.n_document_depth_intervals,\n    )\n    ht.start_test()\n\nif __name__ == \"__main__\":\n    run(main)\n"
  },
  {
    "path": "scripts/run_eval_needle.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport lwm_text_checkpoint=\"\"\n# jsonl file containing text for haystack. Each line should be a json\n# with a single key \"text\" containing the text.\nexport haystack_file=\"\"\nexport output_file=\"\"\n\npython3 -u scripts/eval_needle.py \\\n    --mesh_dim='!1,-1,4,1' \\\n    --dtype='fp32' \\\n    --load_llama_config='7b' \\\n    --update_llama_config=\"dict(theta=10000000,max_sequence_length=131072,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)\" \\\n    --load_checkpoint=\"params::$lwm_text_checkpoint\" \\\n    --tokenizer=\"$llama_tokenizer_path\" \\\n    --max_tokens_per_batch=5000 \\\n    --output_file=\"$output_file\" \\\n    --haystack_file=\"$haystack_file\" \\\n    --context_lengths_min=1000 \\\n    --context_lengths_max=10000 \\\n    --n_context_length_intervals=20 \\\n    --n_document_depth_intervals=20 \\\n    --n_rounds=3\nread\n"
  },
  {
    "path": "scripts/run_eval_needle_multi.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport lwm_text_checkpoint=\"\"\n# jsonl file containing text for haystack. Each line should be a json\n# with a single key \"text\" containing the text.\nexport haystack_file=\"\"\nexport output_file=\"\"\n\npython3 -u scripts/eval_needle_multi.py \\\n    --mesh_dim='!1,1,-1,1' \\\n    --dtype='fp32' \\\n    --load_llama_config='7b' \\\n    --update_llama_config=\"dict(theta=10000000,max_sequence_length=131072,scan_attention=True,scan_query_chunk_size=1024,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)\" \\\n    --load_checkpoint=\"params::$lwm_text_checkpoint\" \\\n    --tokenizer=\"$llama_tokenizer_path\" \\\n    --max_tokens_per_batch=5000 \\\n    --output_file=\"$output_file\" \\\n    --haystack_file=\"$haystack_file\" \\\n    --context_lengths_min=1000 \\\n    --context_lengths_max=10000 \\\n    --n_context_length_intervals=10 \\\n    --n_document_depth_intervals=10 \\\n    --n_needles_total=4 \\\n    --n_needles_retrieve=2 \\\n    --n_rounds=10\nread\n"
  },
  {
    "path": "scripts/run_sample_image.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport vqgan_checkpoint=\"\"\nexport lwm_checkpoint=\"\"\n\n# Relevant params\n# --temperature_*: Temperature that is applied to each of the logits\n# --top_k_*: Only sample from the tokens with the top k logits\n# --cfg_scale_*: Classifier-free guidance scale for each modality\n# --n_frames: Number of frames to generate. For images specify 1.\n\npython3 -u -m lwm.vision_generation \\\n    --prompt='Fireworks over the city' \\\n    --output_file='fireworks.png' \\\n    --temperature_image=1.0 \\\n    --top_k_image=8192 \\\n    --cfg_scale_image=5.0 \\\n    --vqgan_checkpoint=\"$vqgan_checkpoint\" \\\n    --n_frames=1 \\\n    --mesh_dim='!1,1,-1,1' \\\n    --dtype='fp32' \\\n    --load_llama_config='7b' \\\n    --update_llama_config=\"dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)\" \\\n    --load_checkpoint=\"params::$lwm_checkpoint\" \\\n    --tokenizer=\"$llama_tokenizer_path\"\nread\n"
  },
  {
    "path": "scripts/run_sample_video.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport vqgan_checkpoint=\"\"\nexport lwm_checkpoint=\"\"\n\n# Relevant params\n# --temperature_*: Temperature that is applied to each of the logits\n# --top_k_*: Only sample from the tokens with the top k logits\n# --cfg_scale_*: Classifier-free guidance scale for each modality\n# --n_frames: Number of frames to generate\n\npython3 -u -m lwm.vision_generation \\\n    --prompt='Fireworks over the city' \\\n    --output_file='fireworks.mp4' \\\n    --temperature_image=1.0 \\\n    --temperature_video=1.0 \\\n    --top_k_image=8192 \\\n    --top_k_video=1000 \\\n    --cfg_scale_image=5.0 \\\n    --cfg_scale_video=1.0 \\\n    --vqgan_checkpoint=\"$vqgan_checkpoint\" \\\n    --n_frames=8 \\\n    --mesh_dim='!1,1,-1,1' \\\n    --dtype='fp32' \\\n    --load_llama_config='7b' \\\n    --update_llama_config=\"dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)\" \\\n    --load_checkpoint=\"params::$lwm_checkpoint\" \\\n    --tokenizer=\"$llama_tokenizer_path\"\nread\n"
  },
  {
    "path": "scripts/run_train_text.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\nexport LIBTPU_INIT_ARGS=\"--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport dataset_path=\"\"\nexport output_dir=\"\"\n\nexport project_id='lwm'\nexport experiment_note=''\nexport experiment_id='example-text-train'\n\n# mesh_dim: dp, fsdp, tp, sp\npython3 -u -m lwm.train \\\n    --modality='text' \\\n    --mesh_dim='!1,-1,2,2' \\\n    --dtype='fp32' \\\n    --total_steps=200 \\\n    --log_freq=1 \\\n    --save_model_freq=0 \\\n    --save_milestone_freq=10 \\\n    --load_llama_config='debug' \\\n    --update_llama_config=\"dict(theta=10000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=256,scan_key_chunk_size=512,scan_mlp=True,scan_mlp_chunk_size=1024,scan_layers=True)\" \\\n    --tokenizer=\"$llama_tokenizer_path\" \\\n    --optimizer.type='adamw' \\\n    --optimizer.accumulate_gradient_steps=1 \\\n    --optimizer.adamw_optimizer.weight_decay=0.1 \\\n    --optimizer.adamw_optimizer.lr=8e-5 \\\n    --optimizer.adamw_optimizer.end_lr=8e-5 \\\n    --optimizer.adamw_optimizer.lr_warmup_steps=5 \\\n    --optimizer.adamw_optimizer.lr_decay_steps=200 \\\n    --use_data_sharded_loader=True \\\n    --train_dataset.type='json' \\\n    --train_dataset.text_processor.fields='text' \\\n    --train_dataset.json_dataset.path=\"$dataset_path\" 
\\\n    --train_dataset.json_dataset.seq_length=2048 \\\n    --train_dataset.json_dataset.batch_size=1024 \\\n    --train_dataset.json_dataset.tokenizer_processes=16 \\\n    --train_dataset.json_dataset.use_data_sharded_loader=True \\\n    --checkpointer.save_optimizer_state=True \\\n    --autoresume=False \\\n    --logger.append_uuid=False \\\n    --logger.online=False \\\n    --logger.project_id=\"$project_id\" \\\n    --logger.experiment_id=\"$experiment_id\" \\\n    --logger.experiment_note=\"$experiment_note\" \\\n    --logger.output_dir=\"$output_dir\" \\\n    --logger.wandb_dir=\"$HOME/experiment_output/$project_id\"\nread\n"
  },
  {
    "path": "scripts/run_train_vision_text.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\nexport LIBTPU_INIT_ARGS=\"--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport dataset_path=\"\"\nexport output_dir=\"\"\n\nexport project_id='lwm'\nexport experiment_note=''\nexport experiment_id='example-vision-text-train'\n\n# mesh_dim: dp, fsdp, tp, sp\npython3 -u -m lwm.train \\\n    --modality='vision,text' \\\n    --mesh_dim='!1,-1,2,2' \\\n    --dtype='fp32' \\\n    --total_steps=200 \\\n    --log_freq=1 \\\n    --save_model_freq=0 \\\n    --save_milestone_freq=10 \\\n    --load_llama_config='debug' \\\n    --update_llama_config=\"dict(theta=50000000,max_sequence_length=2048,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,scan_mlp=True,scan_mlp_chunk_size=8192,scan_layers=True)\" \\\n    --tokenizer=\"$llama_tokenizer_path\" \\\n    --optimizer.type='adamw' \\\n    --optimizer.accumulate_gradient_steps=1 \\\n    --optimizer.adamw_optimizer.weight_decay=0.1 \\\n    --optimizer.adamw_optimizer.lr=8e-5 \\\n    --optimizer.adamw_optimizer.end_lr=8e-5 \\\n    --optimizer.adamw_optimizer.lr_warmup_steps=5 \\\n    --optimizer.adamw_optimizer.lr_decay_steps=200 \\\n    --use_data_sharded_loader=True \\\n    --train_dataset.type='json_vision' \\\n    --train_dataset.vision_text_processor.fields_from_example='fields' \\\n    
--train_dataset.vision_text_processor.max_n_frames=4 \\\n    --train_dataset.json_vision_dataset.mode=\"no_pad\" \\\n    --train_dataset.json_vision_dataset.path=\"$dataset_path\" \\\n    --train_dataset.json_vision_dataset.seq_length=2048 \\\n    --train_dataset.json_vision_dataset.batch_size=8 \\\n    --train_dataset.json_vision_dataset.tokenizer_processes=4 \\\n    --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=2 \\\n    --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=8 \\\n    --train_dataset.json_vision_dataset.use_data_sharded_loader=True \\\n    --checkpointer.save_optimizer_state=True \\\n    --autoresume=False \\\n    --logger.append_uuid=False \\\n    --logger.online=False \\\n    --logger.project_id=\"$project_id\" \\\n    --logger.experiment_id=\"$experiment_id\" \\\n    --logger.experiment_note=\"$experiment_note\" \\\n    --logger.output_dir=\"$output_dir\" \\\n    --logger.wandb_dir=\"$HOME/experiment_output/$project_id\"\nread\n"
  },
  {
    "path": "scripts/run_vision_chat.sh",
    "content": "#! /bin/bash\n\nexport SCRIPT_DIR=\"$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nexport PROJECT_DIR=\"$( cd -- \"$( dirname -- \"$SCRIPT_DIR\" )\" &> /dev/null && pwd )\"\ncd $PROJECT_DIR\nexport PYTHONPATH=\"$PYTHONPATH:$PROJECT_DIR\"\n\nexport llama_tokenizer_path=\"LargeWorldModel/LWM-Text-1M\"\nexport vqgan_checkpoint=\"\"\nexport lwm_checkpoint=\"\"\nexport input_file=\"\"\n\n# Relevant params\n# --input_file: An image file (png or jpg) or a video file (any video format supported by decord, e.g. mp4)\n# --max_n_frames: Maximum number of frames to process. If the video has more than max_n_frames frames, max_n_frames frames are sampled uniformly from it\n\npython3 -u -m lwm.vision_chat \\\n    --prompt=\"What is the video about?\" \\\n    --input_file=\"$input_file\" \\\n    --vqgan_checkpoint=\"$vqgan_checkpoint\" \\\n    --mesh_dim='!1,1,-1,1' \\\n    --dtype='fp32' \\\n    --load_llama_config='7b' \\\n    --max_n_frames=8 \\\n    --update_llama_config=\"dict(sample_mode='text',theta=50000000,max_sequence_length=131072,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=2048,scan_layers=True)\" \\\n    --load_checkpoint=\"params::$lwm_checkpoint\" \\\n    --tokenizer=\"$llama_tokenizer_path\" \\\n2>&1 | tee ~/output.log\nread\n"
  },
  {
    "path": "scripts/sample_pyt.py",
    "content": "import argparse\nfrom transformers import LlamaForCausalLM, LlamaTokenizer\n\nparser = argparse.ArgumentParser()\nparser.add_argument('-m', '--model', type=str, default='LargeWorldModel/LWM-Text-Chat-256K')\nargs = parser.parse_args()\n\nmodel = LlamaForCausalLM.from_pretrained(args.model)\ntokenizer = LlamaTokenizer.from_pretrained(args.model)\n\n# The chat template is only relevant for chat models; non-chat models do not need it.\nTEMPLATE = \"You are a helpful assistant. USER: {} ASSISTANT:\"\nquestion = \"What is the capital of France?\"\nprompt = TEMPLATE.format(question)\ninputs = tokenizer(prompt, return_tensors=\"pt\")\n\ngenerate_ids = model.generate(inputs.input_ids, max_length=300)\noutput = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]\n\nprint(output)"
  },
  {
    "path": "tpu_requirements.sh",
    "content": "#! /bin/bash\n\nsudo apt-get update && sudo apt-get install -y \\\n    build-essential \\\n    python-is-python3 \\\n    tmux \\\n    htop \\\n    git \\\n    ffmpeg\n\n# Update pip\npip install --upgrade pip\n\n# Python dependencies\ncat > $HOME/tpu_requirements.txt <<- EndOfFile\n-f https://storage.googleapis.com/jax-releases/libtpu_releases.html\njax[tpu]==0.4.29\nflax==0.8.4\noptax==0.2.2\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch==2.0.0\ntransformers==4.40.0\nringattention @ git+https://github.com/haoliuhl/ringattention.git\ndatasets\neinops\ntqdm\nml_collections\nwandb\ngcsfs\nrequests\ntyping-extensions\nsentencepiece\ntux @ git+https://github.com/haoliuhl/tux.git\nPillow\nffmpeg-python\nipdb\nimageio[ffmpeg]\nopencv-python\ndecord\nffmpeg-python\nh5py\npsutil\nEndOfFile\n\npip install --upgrade -r $HOME/tpu_requirements.txt\n\n# vim configurations\ncat > $HOME/.vimrc <<- EndOfFile\nset tabstop=4\nset shiftwidth=4\nset softtabstop=4\nset expandtab\nset backspace=indent,eol,start\nsyntax on\nEndOfFile\n\n# tmux configurations\ncat > $HOME/.tmux.conf <<- EndOfFile\nbind r source-file ~/.tmux.conf \\; display-message \"█▓░ ~/.tmux.conf reloaded.\"\n\n# Enable colors, https://github.com/tmux/tmux/wiki/FAQ\nset -g default-terminal \"tmux-256color\"\n\n# start with window 1 (instead of 0)\nset -g base-index 1\nsetw -g pane-base-index 1\n\nset -g prefix C-a\n\nset -g set-titles on\nset -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)'\n\n# Status bar customization\nset -g status-interval 5\nset -g status-left-length 90\nset -g status-right-length 60\nset -g status-justify left\n\n# send the prefix to client inside window (ala nested sessions)\nbind-key a send-prefix\n\nbind-key x kill-pane\n\n# auto reorder\nset-option -g renumber-windows on\n\n# default window name\nset -g status-left \"#[fg=green,bg=colour236] #S \"\n\n# default statusbar colors\nset-option -g status-style fg=yellow,dim,bg=colour235\n\n# 
default window title colors\nset-window-option -g window-status-style fg=yellow,bg=colour236,dim\n\n# active window title colors\nset-window-option -g window-status-current-style fg=brightred,bg=colour236\n\n# basename as window title https://stackoverflow.com/a/37136828\nset-window-option -g window-status-current-format '#{window_index} #{pane_current_command} #(echo \"#{pane_current_path}\" | rev | cut -d'/' -f-3 | rev)'\nset-window-option -g window-status-format '#{window_index} #{pane_current_command} #(echo \"#{pane_current_path}\" | rev | cut -d'/' -f-3 | rev)'\n\n# pane border\nset-option -g pane-border-style fg=white #base2\nset-option -g pane-active-border-style fg=brightcyan #base1\n\n# enable mouse click\nset -g mouse on\n\n# keep window on\nset -g remain-on-exit on\n\n# Longer scrollback history\nset -g history-limit 50000\n\n# Scroll position indicator\nset -g mode-style bg=colour235,fg=colour245\n\n# SSH agent forwarding\n# set-environment -g SSH_AUTH_SOCK $SSH_AUTH_SOCK\nif-shell '[ -n $SSH_AUTH_SOCK ]' \" \\\n  set-option -sg update-environment \\\"DISPLAY WINDOWID XAUTHORITY\\\"; \\\n  setenv -g SSH_AUTH_SOCK /tmp/ssh_auth_sock_tmux; \\\n  run-shell \\\"ln -sf $(find /tmp/ssh-* -type s -readable | head -n 1) /tmp/ssh_auth_sock_tmux\\\" \\\n\"\n\n# Drag windows on the status bar\nbind-key -n MouseDrag1Status swap-window -t=\nEndOfFile\n\n\n# htop Configurations\nmkdir -p $HOME/.config/htop\ncat > $HOME/.config/htop/htoprc <<- EndOfFile\n# Beware! 
This file is rewritten by htop when settings are changed in the interface.\n# The parser is also very primitive, and not human-friendly.\nfields=0 48 17 18 38 39 40 2 46 47 49 1\nsort_key=46\nsort_direction=1\nhide_threads=0\nhide_kernel_threads=1\nhide_userland_threads=1\nshadow_other_users=0\nshow_thread_names=0\nshow_program_path=1\nhighlight_base_name=0\nhighlight_megabytes=1\nhighlight_threads=1\ntree_view=0\nheader_margin=1\ndetailed_cpu_time=0\ncpu_count_from_zero=0\nupdate_process_names=0\naccount_guest_in_cpu_meter=0\ncolor_scheme=0\ndelay=15\nleft_meters=CPU Memory Swap\nleft_meter_modes=1 1 1\nright_meters=Tasks LoadAverage Uptime\nright_meter_modes=2 2 2\nEndOfFile\n"
  }
]