[
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\n\nResources:\n\n- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)\n- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)\n- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns\n"
  },
  {
    "path": "LICENSE",
    "content": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# MineWorld <br> <sub>A Real-time Interactive World Model on Minecraft</sub>\n\n[![arXiv](https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2504.08388) &ensp; [![Project](https://img.shields.io/badge/Project-Page-blue?logo=homepage&logoColor=white)](https://aka.ms/mineworld) &ensp; [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/microsoft/mineworld)\n</div>\n\nWe introduce MineWorld, an interactive world model on Minecraft that brings several key advancements over existing approaches: \n* 🕹️ **High generation quality**. Built on a visual-action autoregressive Transformer, MineWorld generates coherent, high-fidelity frames conditioned on both visuals and actions. \n* 🕹️ **Strong controllability**. We propose benchmarks for action-following capacity, on which MineWorld shows precise and consistent behavior. \n* 🕹️ **Fast inference speed**. With Diagonal Decoding, MineWorld achieves a generation rate of 4 to 7 frames per second, enabling real-time interaction in open-ended game environments. \n\nhttps://github.com/user-attachments/assets/2f5b4740-badd-453c-970d-061abd367f82\n\n## 🔥 News\n* May, 2025: The model checkpoints in the [Huggingface repo](https://huggingface.co/microsoft/mineworld) have been temporarily taken down.\n* April, 2025: 🚀 [MineWorld](https://github.com/microsoft/mineworld) was released!\n* March, 2025: 🚀 The [Diagonal Decoding](https://arxiv.org/pdf/2503.14070) paper was released!\n\n## 🔧 Setup\n1. Clone this repository and navigate to the MineWorld folder:\n```bash\ngit clone https://github.com/microsoft/mineworld.git\ncd mineworld\n```\n2. We provide a `requirements.txt` file for setting up a pip environment.\n```bash\n# 1. Prepare the conda environment\nconda create -n mineworld python=3.10\n# 2. Activate the environment\nconda activate mineworld\n# 3. 
Install the dependencies\npip3 install -r requirements.txt\n```\n\nWe recommend using a high-end GPU for inference. All testing and development was done on A100 and H100 GPUs.\n\n\n## 🎈 Checkpoints\nDownload the pre-trained models [here](https://huggingface.co/microsoft/mineworld). Each checkpoint has a corresponding config file with the same name in the `configs` folder of this repository. All models share the same VAE checkpoint and config. The directory structure is as follows:\n```\n├── checkpoints\n│   ├── 300M_16f.ckpt\n│   ├── 700M_16f.ckpt\n│   ├── 700M_32f.ckpt\n│   ├── 1200M_16f.ckpt\n│   ├── 1200M_32f.ckpt\n│   └── vae\n│       ├── config.json\n│       └── vae.ckpt\n├── validation\n│   └── validation.zip\n└── gradio_scene\n    ├── scene.mp4\n    └── scene.jsonl\n```\n\n## 🚀 Inference\nWe provide two ways to use our model: interacting with it in a web demo, and running it locally to reproduce the evaluation results in our paper. In addition to downloading the checkpoints and placing them in the `checkpoints` folder, running the web demo also requires downloading `scene.mp4` and `scene.jsonl`. Make sure they are placed in the same directory.\n\n### Run Web Demo\n\nTo launch the web demo, run the following command:\n```bash\npython mineworld.py --scene \"path/to/scene.mp4\" \\\n    --model_ckpt \"path/to/ckpt\" \\\n    --config \"path/to/config\"\n```\n\n![image](assets/demo.png)\n\nOnce the demo is running, you can access the website through the local URL or the public URL displayed in the command line. Initialization and the first action may take some time due to compilation.\n\nYou can specify a reference frame using the `--reference_frame` option, which should be larger than `4` and smaller than the context length of the model (i.e., `16` or `32`, depending on the model used). A higher reference frame number generally corresponds to better visual quality. 
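A note on these context lengths: according to the config files, each frame occupies `347` tokens (`336` image tokens plus `11` action tokens), so the 16- and 32-frame models use `347 x 16 = 5552` and `347 x 32 = 11104` positions, matching `max_position_embeddings` in the `16f` and `32f` configs. A quick sanity check:

```python
IMAGE_TOKENS = 336   # image tokens per frame (image_num in the configs)
ACTION_TOKENS = 11   # action tokens per frame (actnum in diagonal_decoding.py)
FRAME_TOKENS = IMAGE_TOKENS + ACTION_TOKENS  # token_num: 347 in the configs

print(FRAME_TOKENS * 16)  # 5552, max_position_embeddings of the 16f configs
print(FRAME_TOKENS * 32)  # 11104, max_position_embeddings of the 32f configs
```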
Once the initial state has been set, perform game actions by selecting options in each chatbox.\nThe game progresses when you press the \"Run\" button, which displays the last `8` frames and the most recent frame separately. Players can also set an action count to repeat an action multiple times.\n\nExplanations of the buttons in the web demo are as follows:\n```\nStart frame: select a frame in scene.mp4 by its frame index\nJump to start frame: use the selected frame as the initial state\nCamera `X` and `Y`: control the camera movements between `-90` and `90` degrees\nOther action buttons: same as the actions in Minecraft\nGenerate video: save previous game progress\n```\n\n### Run Local Inference\n\nTo run inference locally, use the following command:\n\n```bash\npython inference.py \\\n        --data_root \"/path/to/validation/dataset\" \\\n        --model_ckpt \"path/to/ckpt\" \\\n        --config \"path/to/config\" \\\n        --demo_num 1 \\\n        --frames 15 \\\n        --accelerate-algo 'naive' \\\n        --top_p 0.8 \\\n        --output_dir \"path/to/output\"\n```\n\nCheck `scripts/inference_16f_models.sh` for examples. To switch between naive autoregressive decoding and diagonal decoding, set the `--accelerate-algo` argument to `naive` or `image_diagd`, respectively. 
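For reference, the `--top_p 0.8` flag selects nucleus (top-p) sampling. The filtering step can be sketched in standalone PyTorch; this is a simplified version of the `sample_top_p` helper in `diagonal_decoding.py`, and the function name here is illustrative:

```python
import torch

def top_p_filter_sample(logits: torch.Tensor, top_p: float = 0.8, temperature: float = 1.0) -> torch.Tensor:
    # logits: [batch, vocab] -> sampled token ids: [batch, 1]
    probs = torch.softmax(logits / temperature, dim=-1)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens beyond the smallest prefix whose cumulative mass exceeds top_p
    probs_sort[cumsum - probs_sort > top_p] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map back from sorted order to original vocabulary indices
    return torch.gather(probs_idx, -1, next_token)
```

With `top_p = 0.8`, only the most probable tokens whose cumulative mass stays within `0.8` remain candidates; lower values make generation more deterministic.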
\n\nAfter running inference on a set of videos, you can compute the metrics and reproduce the numerical results in our paper by running the following scripts:\n```bash\nbash scripts/setup_metrics.sh # only required the first time\nbash scripts/compute_metrics.sh\n```\n\nThe evaluation outputs will have the following structure: \n```\n├── videos\n│   ├── inference_setting1\n│   │   ├── clip_1.mp4\n│   │   └── clip_1.json\n│   └── inference_setting2\n│       ├── clip_1.mp4\n│       └── clip_1.json\n└── metrics_log\n    ├── fvd_inference_setting1.json\n    ├── fvd_inference_setting2.json\n    ├── idm_inference_setting1.json\n    ├── idm_inference_setting2.json\n    └── latest_metrics.csv\n```\nAll results will be aggregated into `metrics_log/latest_metrics.csv`.\n\n## 💡 Intended Uses\n\nOur model is trained solely in the Minecraft game domain. As a world model, it is given an initial image of the game scene, and the user selects an action from the action list. The model then generates the next scene, reflecting the selected action.\n\n\n## 🪧 Out-of-scope Uses\n\nOur models are not designed for any tasks or scenarios other than modeling Minecraft. \n\nDevelopers should expect failures in generation results for out-of-scope scenarios. \n\nDevelopers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case, and evaluate and mitigate for privacy, safety, and fairness before using the model within a specific downstream use case, particularly for high-risk scenarios.\n\n## 🤖️ Risks and Limitations \n\nSome of the limitations of this model to be aware of include: \n* Quality of Service: MineWorld is trained solely on Minecraft, so it cannot generate results for other video domains (such as internet video). 
It also cannot generate videos at higher resolutions.\n* Information Reliability: MineWorld is trained on videos with a fixed resolution, so the results may lose detailed information due to the low resolution. \n* MineWorld inherits any biases, errors, or omissions characteristic of its training data, which may be amplified by any AI-generated interpretations. \n* MineWorld was developed for research and experimental purposes. Further testing and validation are needed before considering its application in commercial or real-world scenarios. \n* Inputting images other than Minecraft scenes will result in incoherent imagery and should not be attempted.\n* Users are responsible for sourcing their datasets legally and ethically. This could include securing appropriate copyrights, ensuring consent for use of audio/images, and/or the anonymization of data prior to use in research.\n\n## ✏️ BibTeX\n\n```bibtex\n@article{guo2025mineworld,\n  title={MineWorld: a Real-Time and Open-Source Interactive World Model on Minecraft}, \n  author={Guo, Junliang and Ye, Yang and He, Tianyu and Wu, Haoyu and Jiang, Yushu and Pearce, Tim and Bian, Jiang},\n  year={2025},\n  journal={arXiv preprint arXiv:2504.08388},\n}\n```\n\n## 🤗 Acknowledgments\nThis codebase borrows code from [VPT](https://github.com/openai/Video-Pre-Training) and [generative-models](https://github.com/Stability-AI/generative-models). We thank them for their efforts and innovations, which have made the development process more efficient and convenient.\n\nThank you to everyone who contributed their wisdom and efforts to this project.\n\n## ☎️ Contact\n\nWe welcome feedback and collaboration from our audience. If you have suggestions, questions, or observe unexpected/offensive behavior in our technology, please contact us through `tianyuhe AT microsoft.com`.\n\n## 📄 Contributing\n\nThis project welcomes contributions and suggestions. 
Most contributions require you to agree to a\nContributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us\nthe rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.\n\nWhen you submit a pull request, a CLA bot will automatically determine whether you need to provide\na CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions\nprovided by the bot. You will only need to do this once across all repos using our CLA.\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).\nFor more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or\ncontact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.\n\n\n## 📍 Trademarks\n\nThis project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft \ntrademarks or logos is subject to and must follow \n[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).\nUse of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.\nAny use of third-party trademarks or logos is subject to those third parties' policies.\n"
  },
  {
    "path": "SECURITY.md",
    "content": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).\n\nIf you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.\n\n## Reporting Security Issues\n\n**Please do not report security vulnerabilities through public GitHub issues.**\n\nInstead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).\n\nIf you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).\n\nYou should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). \n\nPlease include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:\n\n  * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.)\n  * Full paths of source file(s) related to the manifestation of the issue\n  * The location of the affected source code (tag/branch/commit or direct URL)\n  * Any special configuration required to reproduce the issue\n  * Step-by-step instructions to reproduce the issue\n  * Proof-of-concept or exploit code (if possible)\n  * Impact of the issue, including how an attacker might exploit the issue\n\nThis information will help us triage your report more quickly.\n\nIf you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.\n\n## Preferred Languages\n\nWe prefer all communications to be in English.\n\n## Policy\n\nMicrosoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).\n\n<!-- END MICROSOFT SECURITY.MD BLOCK -->\n"
  },
  {
    "path": "SUPPORT.md",
    "content": "# TODO: The maintainer of this repo has not yet edited this file\r\n\r\n**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?\r\n\r\n- **No CSS support:** Fill out this template with information about how to file issues and get help.\r\n- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.\r\n- **Not sure?** Fill out an intake as though the answer were \"Yes\". CSS will help you decide.\r\n\r\n*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*\r\n\r\n# Support\r\n\r\n## How to file issues and get help  \r\n\r\nThis project uses GitHub Issues to track bugs and feature requests. Please search the existing \r\nissues before filing new issues to avoid duplicates.  For new issues, file your bug or \r\nfeature request as a new Issue.\r\n\r\nFor help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE \r\nFOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER\r\nCHANNEL. WHERE WILL YOU HELP PEOPLE?**.\r\n\r\n## Microsoft Support Policy  \r\n\r\nSupport for this **PROJECT or PRODUCT** is limited to the resources listed above.\r\n"
  },
  {
    "path": "configs/1200M_16f.yaml",
    "content": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VAE\n      params:\n        config_path: checkpoints/vae/config.json\n        ckpt_path: checkpoints/vae/vae.ckpt        \n    transformer_config:\n      target: transformers.LlamaConfig\n      params:\n        max_position_embeddings: 5552\n        hidden_size: 2048\n        intermediate_size: 8192\n        num_attention_heads: 32\n        num_key_value_heads: 4\n        num_hidden_layers: 20\n        rope_theta: 10000.0\n        torch_dtype: bfloat16\n        rms_norm_eps: 1.0e-05\n        vocab_size: 8262\n        attention_bias: false\n        mlp_bias: false"
  },
  {
    "path": "configs/1200M_32f.yaml",
    "content": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VAE\n      params:\n        config_path: checkpoints/vae/config.json\n        ckpt_path: checkpoints/vae/vae.ckpt   \n    transformer_config:\n      target: transformers.LlamaConfig\n      params:\n        max_position_embeddings: 11104\n        hidden_size: 2048\n        intermediate_size: 8192\n        num_attention_heads: 32\n        num_key_value_heads: 4\n        num_hidden_layers: 20\n        rope_theta: 10000.0\n        torch_dtype: bfloat16\n        rms_norm_eps: 1.0e-05\n        vocab_size: 8262\n        attention_bias: false\n        mlp_bias: false"
  },
  {
    "path": "configs/300M_16f.yaml",
    "content": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VAE\n      params:\n        config_path: checkpoints/vae/config.json\n        ckpt_path: checkpoints/vae/vae.ckpt     \n    transformer_config:\n      target: transformers.LlamaConfig\n      params:\n        max_position_embeddings: 5552\n        hidden_size: 1024\n        intermediate_size: 4096\n        num_attention_heads: 16\n        num_key_value_heads: 4\n        num_hidden_layers: 20\n        initializer_range: 0.02\n        rope_theta: 10000.0\n        torch_dtype: bfloat16\n        rms_norm_eps: 1.0e-05\n        vocab_size: 8262\n        attention_bias: false\n        mlp_bias: false\n        token_num: 347\n        image_num: 336\n        frame: 16"
  },
  {
    "path": "configs/700M_16f.yaml",
    "content": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VAE\n      params:\n        config_path: checkpoints/vae/config.json\n        ckpt_path: checkpoints/vae/vae.ckpt       \n    transformer_config:\n      target: transformers.LlamaConfig\n      params:\n        max_position_embeddings: 5552\n        hidden_size: 2048\n        intermediate_size: 4096\n        num_attention_heads: 32\n        num_key_value_heads: 4\n        num_hidden_layers: 20\n        rope_theta: 10000.0\n        torch_dtype: bfloat16\n        rms_norm_eps: 1.0e-05\n        vocab_size: 8262\n        attention_bias: false\n        mlp_bias: false"
  },
  {
    "path": "configs/700M_32f.yaml",
    "content": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VAE\n      params:\n        config_path: checkpoints/vae/config.json\n        ckpt_path: checkpoints/vae/vae.ckpt       \n    transformer_config:\n      target: transformers.LlamaConfig\n      params:\n        max_position_embeddings: 11104\n        hidden_size: 2048\n        intermediate_size: 4096\n        num_attention_heads: 32\n        num_key_value_heads: 4\n        num_hidden_layers: 20\n        rope_theta: 10000.0\n        torch_dtype: bfloat16\n        rms_norm_eps: 1.0e-05\n        vocab_size: 8262\n        attention_bias: false\n        mlp_bias: false\n\n"
  },
  {
    "path": "diagonal_decoding.py",
    "content": "import torch\nfrom typing import Optional\nfrom torch.nn.attention import SDPBackend\n\ndef sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192):\n    \"\"\"\n    Sample from the logits using top-k sampling.\n    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py\n    \"\"\"\n    # logits: [batch_size, seq_len, vocab_size]\n    if temperature == 0.0:\n        idx_next = torch.argmax(logits[:, -1, :vocab_size], dim=-1, keepdim=True)\n    else:\n        probs = logits_to_probs(logits[:, -1, :vocab_size], temperature, top_k)\n        idx_next = multinomial_sample_one_no_sync(probs)\n    return idx_next\n\ndef multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):\n    \"\"\"\n    Multinomial sampling without a cuda synchronization.\n    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py\n    \"\"\"\n    q = torch.empty_like(probs_sort).exponential_(1)\n    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype)\n\ndef logits_to_probs(\n    logits,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n):\n    logits = logits / max(temperature, 1e-5)\n\n    if top_k is not None:\n        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n        pivot = v.select(-1, -1).unsqueeze(-1)\n        logits = torch.where(logits < pivot, -float(\"Inf\"), logits)\n    probs = torch.nn.functional.softmax(logits, dim=-1)\n    return probs\n\ndef sample_top_p(logits, temperature, top_p, vocab_size=8192):\n    probs = torch.softmax(logits[:, -1, :vocab_size] / temperature, dim=-1)\n    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n    probs_sum = torch.cumsum(probs_sort, dim=-1)\n    mask = probs_sum - probs_sort > top_p\n    probs_sort[mask] = 0.0\n    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))\n    next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)\n    next_token = torch.gather(probs_idx, 
-1, next_token)\n    return next_token\n\ndef sample_n_top_p(logits, temperature, top_p, vocab_size=8192):\n    probs = torch.softmax(logits[:, :, :vocab_size] / temperature, dim=-1)\n    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n    probs_sum = torch.cumsum(probs_sort, dim=-1)\n    mask = probs_sum - probs_sort > top_p\n    probs_sort[mask] = 0.0\n    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))\n    next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)\n    next_token = torch.gather(probs_idx, -1, next_token)\n    return next_token\n\ndef sample_n_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192):\n    if temperature == 0.0:\n        # Modify for multiple logits (n items)\n        idx_next = torch.argmax(logits[:, :, :vocab_size], dim=-1, keepdim=True)  # Use all n logits for top-k\n        probs = None\n    else:\n        probs = logits_to_n_probs(logits[:, :, :vocab_size], temperature, top_k)\n        idx_next = multinomial_sample_one_no_sync(probs)\n\n    return idx_next\n\ndef logits_to_n_probs(\n    logits,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n):\n    logits = logits / max(temperature, 1e-5)\n\n    if top_k is not None:\n        v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)\n        pivot = v.select(-1, -1).unsqueeze(-1)\n        logits = torch.where(logits < pivot, -float(\"Inf\"), logits)\n    probs = torch.nn.functional.softmax(logits, dim=-1)\n    return probs\n\ndef decode_one_token(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: Optional[float] = None,\n):\n    \"\"\"\n    Decode a single token from the autoregressive model.\n    \"\"\"\n    logits = model(input_ids=input_ids, position_ids=position_ids)\n    if top_p is not None:\n        return sample_top_p(logits, temperature=temperature, top_p=top_p)\n    
else:\n        return sample_top_k(logits, temperature=temperature, top_k=top_k)\n    \ndef decode_some_token(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: Optional[float] = None,\n):\n    \"\"\"\n    Decode multiple tokens from the autoregressive model.\n    \"\"\"\n    logits = model(input_ids=input_ids, position_ids=position_ids)\n    if top_p is not None:\n        return sample_n_top_p(logits, temperature=temperature, top_p=top_p)\n    else:\n        return sample_n_top_k(logits, temperature=temperature, top_k=top_k)\n    \ndef decode_n_tokens(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    num_generate_tokens: int,\n    temperature: float = 1.0,\n    top_p: Optional[float] = 0.8,\n    top_k: Optional[int] = None,\n    decode_one_token_function=decode_one_token,\n    pixnum: int = 336,\n    actnum: int = 11,\n    **kwargs,\n):\n    \"\"\"\n    Decode n tokens from the autoregressive model.\n    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py\n    \"\"\"\n    new_tokens = [input_ids]\n    pos_ = position_ids\n    assert (\n        top_p is None or top_k is None\n    ), f\"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}\"\n\n    for t in range(num_generate_tokens):\n        with torch.nn.attention.sdpa_kernel(\n            SDPBackend.MATH\n        ):  # Actually better for Inductor to codegen attention here\n            next_token = decode_one_token_function(\n                model,\n                input_ids=input_ids,\n                position_ids=position_ids,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            )\n            pos_ += 1\n            position_ids = pos_\n            new_tokens.append(next_token.clone())\n            input_ids = next_token.clone()\n\n            if (pos_ - pixnum + 1) % 
(actnum + pixnum) == 0 and t + 2 < num_generate_tokens:\n                action = kwargs[\"action\"][(t + 2) // pixnum]\n                input_ids = torch.cat((input_ids, action), dim=-1)\n                position_ids = torch.tensor([pos_ + _ for _ in range(actnum + 1)], dtype=torch.long, device=\"cuda\")\n                pos_ += actnum\n\n    return new_tokens\n\ndef decode_n_tokens_for_gradio(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    num_generate_tokens: int,\n    temperature: float = 1.0,\n    top_p: Optional[float] = 0.8,\n    top_k: Optional[int] = None,\n    decode_one_token_function=decode_one_token,\n):\n    \"\"\"\n    Decode n tokens from the autoregressive model.\n    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py\n    \"\"\"\n    new_tokens = []\n    assert (\n        top_p is None or top_k is None\n    ), f\"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}\"\n    position_id = position_ids[-1].unsqueeze(0)\n    assert num_generate_tokens % 336 == 1, \"should be pixnum x n + 1 to fill kvcache\"\n    for t in range(num_generate_tokens):\n        with torch.nn.attention.sdpa_kernel(\n            SDPBackend.MATH\n        ):  # Actually better for Inductor to codegen attention here\n            next_token = decode_one_token_function(\n                model,\n                input_ids=input_ids,\n                position_ids=position_ids,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            )\n            position_id += 1\n            position_ids = position_id\n            new_tokens.append(next_token.clone())\n            input_ids = next_token.clone()\n    return new_tokens[:-1], position_id\n\ndef prefill(\n    model,\n    input_ids: torch.Tensor = None,\n    position_ids: torch.Tensor = None,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: Optional[float] = 0.8,\n    
**kwargs,\n):\n    logits = model(input_ids=input_ids, position_ids=position_ids)\n    # Only top-p or top-k can be provided\n    assert (\n        top_p is None or top_k is None\n    ), f\"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}\"\n    if top_p is not None:\n        return sample_top_p(logits, temperature=temperature, top_p=top_p)\n    else:\n        return sample_top_k(logits, temperature=temperature, top_k=top_k)\n\ndef img_diagd_prepare_inputs(\n    ongoing_row_list,\n    row_token_num,\n    ongoing_input,\n    prompt,\n    imagenum,\n    pixnum: int = 336,\n    actnum: int = 11,\n    columnnum: int = 24,\n    promptlen: int = 347,\n    **kwargs\n):\n    position_ids = []\n\n    for i in ongoing_row_list:\n        global_idx = promptlen + i * columnnum + row_token_num[i] - 1 + (imagenum - 1) * (pixnum + actnum)\n        position_ids.append(global_idx)\n\n    if row_token_num[ongoing_row_list[-1]] == 0:\n        append_policy = kwargs.get(\"append_policy\", True)\n        if append_policy:\n            idx_in_input_ids = ongoing_row_list[-1] * columnnum - 1\n            ongoing_input.append(prompt[:, idx_in_input_ids].unsqueeze(-1))\n        else:\n            ongoing_input.append(ongoing_input[-1])\n\n    input_ids = torch.cat(ongoing_input, dim=1)\n    position_ids = torch.tensor(position_ids, device=\"cuda\")\n\n    return input_ids, position_ids\n\ndef img_diagd_decode_n_tokens(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    num_generate_tokens: int,\n    temperature: float = 1.0,\n    top_p: Optional[float] = 0.8,\n    top_k: Optional[int] = None,\n    decode_some_token_function=decode_some_token,\n    pixnum: int = 336,\n    actnum: int = 11,\n    columnnum: int = 24,\n    rownum: int = 14,\n    windowsize: int = 2,\n    promptlen: int = 347,\n    **kwargs,\n):\n    assert (\n        top_p is None or top_k is None\n    ), f\"Only one of top-p or top-k can be provided, got 
top-p={top_p} and top-k={top_k}\"\n\n    imagenum = 1\n    cur_len = 1\n    num_generate_tokens += 1\n    prompt = kwargs.pop(\"prompt\", None) \n    new_tokens = [input_ids.clone()]\n    row_token_num = torch.zeros((rownum,), dtype=torch.long, device=\"cuda\")\n    row_token_num[0] += 1 \n    ongoing_row_list = [0]\n    ongoing_input = [input_ids.clone()]\n\n    while True:\n        if cur_len >= num_generate_tokens:\n            break\n\n        if cur_len % pixnum == 0 :#and image_start_token_id_index is None: \n            imagenum += 1\n            action = kwargs[\"action\"][cur_len // pixnum]\n            ongoing_input.append(action)\n            input_id = torch.cat(ongoing_input, dim=-1)\n            position_ids = torch.arange(imagenum * (pixnum+actnum) - actnum - 1, imagenum * (pixnum+actnum), device=\"cuda\")\n\n        image_token_num = cur_len % pixnum\n\n        if image_token_num == 1 and row_token_num[0] == windowsize:\n            ongoing_row_list.append(1)\n\n        if image_token_num >= 1:\n            input_id, position_ids = img_diagd_prepare_inputs(ongoing_row_list=ongoing_row_list, ongoing_input = ongoing_input, imagenum=imagenum, row_token_num=row_token_num, promptlen=promptlen, prompt=prompt,**kwargs)  \n            \n        num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1\n        with torch.nn.attention.sdpa_kernel(\n            SDPBackend.MATH\n        ):  # Actually better for Inductor to codegen attention here\n            next_token = decode_some_token_function(\n                model,\n                input_ids=input_id,\n                position_ids=position_ids,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            )\n            ongoing_input = []\n        if len(ongoing_row_list) == 0:\n            cur_len += 1\n            ongoing_input.append(next_token[:,-1].clone())\n            new_tokens.append(next_token[:,-1].clone())\n            
ongoing_row_list.append(0)\n            row_token_num[0] += 1 \n        else:\n            need_remove_row = None\n            cur_len += num_new_tokens\n            for i in range(num_new_tokens):\n                position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0) + (imagenum - 1) * pixnum \n                new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())\n                ongoing_input.append(next_token[:,i].clone())\n                row_token_num[ongoing_row_list[i]] += 1\n\n                if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1:\n                    ongoing_row_list.append(ongoing_row_list[i]+1)\n\n                elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum:\n                    row_token_num = torch.zeros((rownum,), dtype=torch.long, device=\"cuda\")\n                    ongoing_row_list = []\n                    ongoing_input = [next_token[:,i]]\n                    need_remove_row = None\n                    break\n\n                if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done\n                    ongoing_input.pop()\n                    need_remove_row = ongoing_row_list[i]\n\n            if need_remove_row is not None:\n                ongoing_row_list.remove(need_remove_row)\n    return new_tokens\n\ndef img_diagd_prepare_inputs_for_gradio(\n    ongoing_row_list,\n    row_token_num,\n    ongoing_input,\n    pixnum: int = 336,\n    actnum: int = 11,\n    columnnum: int = 24,\n    promptlen: int = 347,\n):\n    position_ids = []\n    \n    for i in ongoing_row_list:\n        global_idx = promptlen + i * columnnum + row_token_num[i] - 1\n        position_ids.append(global_idx)\n\n    if row_token_num[ongoing_row_list[-1]] == 0:\n        ongoing_input.append(ongoing_input[-1])\n\n    input_ids = torch.cat(ongoing_input, dim=1)\n    position_ids = torch.tensor(position_ids, 
device=\"cuda\")\n\n    return input_ids, position_ids\n\ndef img_diagd_decode_n_token_for_gradio(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    num_generate_tokens: int,\n    temperature: float = 1.0,\n    top_p: Optional[float] = 0.8,\n    top_k: Optional[int] = None,\n    decode_some_token_function=decode_some_token,\n    pixnum: int = 336,\n    columnnum: int = 24,\n    rownum: int = 14,\n    windowsize: int = 2,\n):\n    assert (\n        top_p is None or top_k is None\n    ), f\"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}\"\n\n    cur_len = 0\n    promptlen = position_ids[-1] + 1\n    \n    new_tokens = []\n    row_token_num = torch.zeros((rownum,), dtype=torch.long, device=\"cuda\")\n    ongoing_row_list = []\n    ongoing_input = []\n    \n    while True:\n        if cur_len == num_generate_tokens:\n            break\n\n        image_token_num = cur_len\n\n        if image_token_num == 1 and row_token_num[0] == windowsize:\n            ongoing_row_list.append(1)\n        if image_token_num == 0:\n            input_id = input_ids\n\n        if image_token_num >= 1:\n            input_id, position_ids = img_diagd_prepare_inputs_for_gradio(ongoing_row_list=ongoing_row_list, ongoing_input=ongoing_input, row_token_num=row_token_num, promptlen=promptlen)\n\n        num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1\n        with torch.nn.attention.sdpa_kernel(\n            SDPBackend.MATH\n        ):  # Actually better for Inductor to codegen attention here\n            next_token = decode_some_token_function(\n                model,\n                input_ids=input_id,\n                position_ids=position_ids,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            )\n            ongoing_input = []\n        if len(ongoing_row_list) == 0:\n            cur_len += 1\n           \u0020
ongoing_input.append(next_token[:,-1].clone())\n            new_tokens.append(next_token[:,-1].clone())\n            ongoing_row_list.append(0)\n            row_token_num[0] += 1 \n        else:\n            need_remove_row = None\n            cur_len += num_new_tokens\n            for i in range(num_new_tokens):\n                position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0)\n                new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())\n                ongoing_input.append(next_token[:,i].clone())\n                row_token_num[ongoing_row_list[i]] += 1\n\n                if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1:\n                    ongoing_row_list.append(ongoing_row_list[i]+1)\n\n                elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum:\n                    row_token_num = torch.zeros((rownum,), dtype=torch.long, device=\"cuda\")\n                    ongoing_row_list = []\n                    ongoing_input = [next_token[:,i]]\n                    need_remove_row = None\n                    break\n\n                if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done\n                    ongoing_input.pop()\n                    need_remove_row = ongoing_row_list[i]\n\n            if need_remove_row is not None:\n                ongoing_row_list.remove(need_remove_row)\n\n    with torch.nn.attention.sdpa_kernel(\n            SDPBackend.MATH\n        ):  # Actually better for Inductor to codegen attention here\n            _ = decode_some_token_function(\n                model,\n                input_ids=next_token[:,-1],\n                position_ids=position_ids+1,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            )\n    \n    return new_tokens, position_ids+2\n\n\n\ndef vid_diagd_prepare_inputs(\n    ongoing_row_list_v,\n   
 row_token_num_v,\n    ongoing_input_v,\n    prompt,\n    pixnum: int = 336,\n    actnum: int = 11,\n    rownum: int = 14,\n    columnnum: int = 24,\n    promptlen: int = 347,\n    **kwargs\n):\n    new_frame = False\n    position_ids = []\n\n    for i in ongoing_row_list_v:\n        global_idx = promptlen + i * columnnum + row_token_num_v[i // rownum][i % rownum] -1 + (i // rownum) * actnum\n        position_ids.append(global_idx)\n\n    lastrow = ongoing_row_list_v[-1]\n    if lastrow % rownum == 0 and row_token_num_v[lastrow // rownum][lastrow % rownum] == 0:\n        # WARNING\n        action = kwargs[\"action\"][lastrow // rownum]\n        ongoing_input_v.append(action)\n        position_ids.pop()\n        pos_act = torch.arange( promptlen + (lastrow // rownum) * (pixnum+actnum) - actnum, promptlen + (lastrow // rownum) * (pixnum+actnum), device=\"cuda\")\n        position_ids.extend(pos_act.unbind())\n        new_frame = True\n    elif row_token_num_v[lastrow // rownum][lastrow % rownum] == 0:\n        append_policy = kwargs.get(\"append_policy\", True)\n        if append_policy:\n            idx_in_input_ids = (lastrow % rownum) * columnnum - 1\n            ongoing_input_v.append(prompt[:, idx_in_input_ids].unsqueeze(-1))\n        else:\n            ongoing_input_v.append(ongoing_input_v[-1])\n\n    input_ids = torch.cat(ongoing_input_v, dim=1)\n    position_ids = torch.tensor(position_ids, device=\"cuda\")\n\n    return input_ids, position_ids, new_frame\n\ndef video_diagd_decode_n_tokens(\n    model,\n    input_ids: torch.Tensor,\n    position_ids: torch.Tensor,\n    num_generate_tokens: int,\n    temperature: float = 1.0,\n    top_p: Optional[float] = 0.8,\n    top_k: Optional[int] = None,\n    decode_some_token_function=decode_some_token,\n    pixnum: int = 336,\n    actnum: int = 11,\n    columnnum: int = 24,\n    rownum: int = 14,\n    windowsize: int = 2,\n    promptlen: int = 347,\n    **kwargs,\n):\n    assert (\n        top_p is None or top_k is 
None\n    ), f\"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}\"\n\n    cur_len = 1\n    num_generate_tokens += 1\n    prompt = kwargs.pop(\"prompt\", None)\n    new_tokens = [input_ids.clone()]\n    row_token_num_v = []\n    ongoing_row_list_v = [0]\n    row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device=\"cuda\"))\n    row_token_num_v[0][0] += 1\n    if row_token_num_v[0][0] == windowsize:\n        ongoing_row_list_v.append(1)\n\n    ongoing_input_v = [input_ids.clone()]\n\n    while True:\n        if cur_len >= num_generate_tokens:\n            break\n\n        input_id, position_ids, new_frame = vid_diagd_prepare_inputs(ongoing_row_list_v=ongoing_row_list_v, ongoing_input_v=ongoing_input_v, row_token_num_v=row_token_num_v, promptlen=promptlen, prompt=prompt, **kwargs)\n\n        num_new_tokens = input_id.shape[1]\n\n        with torch.nn.attention.sdpa_kernel(\n            SDPBackend.MATH\n        ):  # Actually better for Inductor to codegen attention here\n            next_token = decode_some_token_function(\n                model,\n                input_ids=input_id,\n                position_ids=position_ids,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n            )\n            ongoing_input_v = []\n            if new_frame:\n                next_token = torch.cat([next_token[:,:-actnum], next_token[:,-1:]], dim=1)\n                num_new_tokens = num_new_tokens - actnum + 1\n\n        need_remove_row = None\n\n        cur_len += num_new_tokens\n        for i in range(num_new_tokens):\n            last_frame = torch.stack(row_token_num_v[:ongoing_row_list_v[i] // rownum]).sum() if ongoing_row_list_v[i] // rownum > 0 else torch.tensor(0, dtype=torch.long, device=\"cuda\")\n            position_in_new_tokens = last_frame + torch.sum(row_token_num_v[ongoing_row_list_v[i] // rownum][:(ongoing_row_list_v[i] % rownum +\u0020
1)], dim=0)\n                    \n            new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())\n            ongoing_input_v.append(next_token[:,i].clone())\n            row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] += 1\n\n            # WARNING\n            if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == windowsize and ongoing_row_list_v[i] < rownum * (num_generate_tokens//pixnum) - 1:\n                ongoing_row_list_v.append(ongoing_row_list_v[i]+1)\n                if ongoing_row_list_v[-1] % rownum == 0:\n                    row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device=\"cuda\"))\n            if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == columnnum:\n                ongoing_input_v.pop()\n                need_remove_row = ongoing_row_list_v[i]\n\n        if need_remove_row is not None:\n            ongoing_row_list_v.remove(need_remove_row)\n    return new_tokens\n\n"
  },
  {
    "path": "inference.py",
    "content": "import os\nimport cv2\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import tqdm\nfrom rich import print\nfrom PIL import Image\nfrom pathlib import Path\nfrom torch import autocast\nfrom einops import rearrange\nfrom mcdataset import MCDataset\nfrom omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom argparse import ArgumentParser\nfrom utils import load_model, tensor_to_uint8\ntorch.backends.cuda.matmul.allow_tf32 = False\n\nACCELERATE_ALGO = [\n    'naive','image_diagd'\n]\n\nTARGET_SIZE=(224,384)\nTOKEN_PER_IMAGE = 347 # tokens per frame: 336 pixel tokens + 11 action tokens\nTOKEN_PER_PIX = 336\n\nsafe_globals = {\"array\": np.array}\n\n\ndef token2video(code_list, tokenizer, save_path, fps, device = 'cuda'):\n    \"\"\"\n    Decode a flat list of image tokens into an .mp4 video.\n\n    Note: path processing is intentionally left to the caller for extensibility.\n    save_path: str, path to save the video; expected to end with .mp4\n    \"\"\"\n    if len(code_list) % TOKEN_PER_PIX != 0:\n        print(f\"code_list length {len(code_list)} is not a multiple of {TOKEN_PER_PIX}\")\n        return\n    num_images = len(code_list) // TOKEN_PER_PIX\n    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n    video = cv2.VideoWriter(save_path, fourcc, fps, (384, 224))\n    for i in range(num_images):\n        code = code_list[i*TOKEN_PER_PIX:(i+1)*TOKEN_PER_PIX]\n        code = torch.tensor([int(x) for x in code], dtype=torch.long).to(device)\n        img = tokenizer.token2image(code) # pixel\n        frame = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n        video.write(frame)\n    video.release()\n\ndef get_args():\n    parser = ArgumentParser()\n    parser.add_argument('--data_root', type=str, required=True)\n    parser.add_argument('--model_ckpt', type=str, required=True)\n    parser.add_argument('--config', type=str, required=True)\n    parser.add_argument('--output_dir', type=str, required=True)\n    parser.add_argument('--demo_num', type=int, default=1)\n    parser.add_argument('--frames', type=int, required=True)\n \u0020
  parser.add_argument('--window_size', type=int, default=2)\n    parser.add_argument('--accelerate-algo', type=str, default='naive', help=f\"Accelerate Algorithm Option: {ACCELERATE_ALGO}\")\n    parser.add_argument('--fps', type=int, default=6)\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument('--top_k', type=int, help='Use top-k sampling')\n    group.add_argument('--top_p', type=float, help='Use top-p (nucleus) sampling')\n    parser.add_argument('--val_data_num', type=int, default=500, help=\"number of validation data\")\n    args = parser.parse_args()\n    return args\n\n\ndef lvm_generate(args, model, output_dir, demo_video):\n    \"\"\"\n    Generate args.frames new frames conditioned on one demo video and its action sequence.\n    \"\"\"\n    ### 1. set video input/output path\n    input_mp4_path = os.path.join(args.data_root, demo_video)\n    input_action_path = os.path.join(args.data_root, demo_video.replace('.mp4', '.jsonl'))\n\n    output_mp4_path = str(output_dir / demo_video)\n    output_action_path = output_mp4_path.replace('.mp4', '.jsonl')\n    if os.path.exists(output_mp4_path):\n        print(f\"output path {output_mp4_path} already exists, skip\")\n        return {}\n    # backup action\n    os.system(f\"cp {input_action_path} {output_action_path}\")\n    \n    device = model.transformer.device\n    ### 2. load action into list \n    action_list = []\n    mcdataset = MCDataset()\n    with open(input_action_path, 'r') as f:\n        for line in f:\n            line = eval(line.strip(), {\"__builtins__\": None}, safe_globals)\n            line['camera'] = np.array(line['camera'])\n            act_index = mcdataset.get_action_index_from_actiondict(line, action_vocab_offset=8192)\n            action_list.append(act_index)\n    ### 3.\u0020
load video frames \n    cap = cv2.VideoCapture(input_mp4_path)\n    start_frame = 0\n    end_frame = args.demo_num\n    frames = []\n    for frame_idx in range(start_frame, end_frame):\n        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)\n        ret, frame = cap.read()\n        if not ret:\n            print(f\"Error in reading frame {frame_idx}\")\n            continue\n        cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)\n        frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)\n        frame = torch.from_numpy(frame)\n        frames.append(frame)\n    frames = torch.stack(frames, dim=0).to(device)\n    frames = frames.permute(0, 3, 1, 2)\n    frames = frames.float() / 255.0\n    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n    frames = normalize(frames)\n    \n    with torch.no_grad(), torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n        img_index = model.tokenizer.tokenize_images(frames)\n    img_index = rearrange(img_index, '(b t) h w -> b t (h w)', b=1)\n     \n    all_generated_tokens = []\n\n    action_all = action_list[end_frame: end_frame + args.frames]\n    action_all = torch.tensor(action_all).unsqueeze(1).to(device)\n    image_input = rearrange(img_index, 'b t c -> b (t c)')\n    \n\n    start_t = time.time() \n    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):\n        if args.accelerate_algo == 'naive':\n            outputs = model.transformer.naive_generate( input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all, top_k=args.top_k, top_p=args.top_p)\n        elif args.accelerate_algo == 'image_diagd':\n            outputs = model.transformer.img_diagd_generate(input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all,windowsize = args.window_size, top_k=args.top_k, top_p=args.top_p)\n        else:\n            raise ValueError(f\"Unknown accelerate algorithm {args.accelerate_algo}\")\n    
end_t = time.time()\n    all_generated_tokens.extend(outputs.tolist()[0])\n    new_length = len(all_generated_tokens)\n    time_costed = end_t - start_t\n    token_per_sec = new_length / time_costed\n    frame_per_sec = token_per_sec / TOKEN_PER_PIX\n    print(f\"{new_length} tokens generated; cost {time_costed:.3f} seconds; {token_per_sec:.3f} token/sec {frame_per_sec:.3f} fps\")\n    token2video(all_generated_tokens, model.tokenizer, str(output_dir / demo_video), args.fps, device)\n    # return for evaluation\n    return_item = {\n        \"time_costed\": time_costed,\n        \"token_num\": new_length,\n    }\n    return return_item\n\nif __name__ == '__main__':\n    args = get_args()\n    config = OmegaConf.load(args.config)\n    output_path = Path(args.output_dir)\n    precision_scope = autocast\n    os.makedirs(output_path, exist_ok=True)\n\n    model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True)\n    print(f\"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] Load Model From {args.model_ckpt}\")\n    # normalize the accelerate algorithm name\n    args.accelerate_algo = args.accelerate_algo.lower()\n    if args.accelerate_algo not in ACCELERATE_ALGO:\n        print(f\"[bold red][Warning][/bold red] {args.accelerate_algo} is not in {ACCELERATE_ALGO}, use naive\")\n        args.accelerate_algo = 'naive'\n    num_item = 0\n    for root, _, files in os.walk(args.data_root):\n        files = [f for f in files if f.endswith('.mp4')] # keep only .mp4 files so other files do not affect the progress bar\n        files = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0]))\n        for file in tqdm(files):\n            return_item = lvm_generate(args, model, output_path, file)\n            num_item += 1\n            if num_item >= args.val_data_num:\n                print(f\"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] reach val data num limit {args.val_data_num}\")\n                break"
  },
  {
    "path": "lvm.py",
    "content": "\"\"\"\n    Wrap the Huggingface Transformers Llama to PyTorch Lightning Module.\n\"\"\"\nimport os\nimport sys\nimport inspect \nimport torch\nfrom typing import Optional\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import logging\nfrom transformers import LlamaConfig\nfrom utils import get_obj_from_str, instantiate_from_config\nfrom diagonal_decoding import decode_one_token, decode_some_token, decode_n_tokens, decode_n_tokens_for_gradio, prefill, img_diagd_decode_n_tokens, video_diagd_decode_n_tokens, img_diagd_decode_n_token_for_gradio\ntorch.backends.cuda.matmul.allow_tf32 = False\n\nlogger = logging.get_logger(__name__)\ncurrentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))\nparentdir = os.path.dirname(currentdir)\n\nif not (parentdir in sys.path):\n    sys.path.insert(0, parentdir) \n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids):\n    \"\"\"\n    Apply rotary position embeddings to query and key tensors.\n\n    Args:\n        q (torch.Tensor): Query tensor.\n        k (torch.Tensor): Key tensor.\n        cos (torch.Tensor): Cosine values.\n        sin (torch.Tensor): Sine values.\n        position_ids (torch.Tensor): Position IDs.\n\n    Returns:\n        torch.Tensor: Query and key tensors with rotary position embeddings applied.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(0).unsqueeze(2)\n    sin = sin[position_ids].unsqueeze(0).unsqueeze(2)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass 
LlamaLVM(torch.nn.Module):\n    def __init__(\n        self,\n        transformer_config,\n        model_class: str,\n        tokenizer_config = None,\n    ):\n        super().__init__()\n        self.config = instantiate_from_config(transformer_config)\n        self.transformer = get_obj_from_str(model_class)(self.config)\n        if tokenizer_config is not None:\n            self.tokenizer = instantiate_from_config(tokenizer_config)\n\nclass LlamaRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\nclass LlamaRotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        device=None,\n        config: Optional[LlamaConfig] = None,\n    ):\n        super().__init__()\n        self.rope_kwargs = {}\n        self.rope_type = \"default\"\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n        self.max_position_embeddings = config.max_position_embeddings\n        inv_freq, _ = self.rope_init_fn(self.config, device, **self.rope_kwargs)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self._set_cos_sin_cache(\n            device=self.inv_freq.device,\n            dtype=torch.get_default_dtype(),\n        )\n    \n    def _set_cos_sin_cache(self, device, dtype):\n        \"\"\"\n        Set the cosine and sine 
cache for positional embeddings.\n\n        Args:\n            device (str): The device on which the cache tensors will be stored.\n            dtype: The data type of the cache tensors.\n        \"\"\"\n        t = torch.arange(\n            self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\n            \"cos_cached\", emb.cos().to(dtype), persistent=False\n        )\n        self.register_buffer(\n            \"sin_cached\", emb.sin().to(dtype), persistent=False\n        )\n    \n    def forward(self, x, seq_len=None):\n        \"\"\"\n        Forward pass of the LlamaRotaryEmbedding module.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size].\n            seq_len (int): The sequence length. Must not exceed max_position_embeddings.\n\n        Returns:\n            tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim].\n        \"\"\"\n        if seq_len > self.max_position_embeddings:\n            raise ValueError(\"seq_len must not exceed max_position_embeddings\")\n\n        return (\n            self.cos_cached[:seq_len, :].to(dtype=x.dtype),\n            self.sin_cached[:seq_len, :].to(dtype=x.dtype),\n        )\n\nclass LlamaMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size,\u0020
bias=config.mlp_bias)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n  \nclass LlamaAttention(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = getattr(config, \"head_dim\", self.hidden_size // self.num_heads)\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        assert (self.head_dim * self.num_heads) == self.hidden_size, \"hidden_size must be divisible by num_heads\"\n        self.q_proj = nn.Linear(\n            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias\n        )\n        self.rotary_emb = LlamaRotaryEmbedding(config=config)\n        self.max_batch_size = getattr(config, \"max_batch_size\", 1)\n        self.init_kv_cache()\n\n    def init_kv_cache(self, dtype=torch.float16):\n        cache_shape = (self.max_batch_size, self.max_position_embeddings, self.num_key_value_heads, self.head_dim)\n        self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()\n        self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            
attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            positions_embedding = None,\n    ):\n        \n        bsz, q_len, _ = hidden_states.size()\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        )\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        )\n        value_states = value_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        )\n\n        cos, sin = positions_embedding\n\n\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin, position_ids\n        )\n        \n        self.cache_k[:bsz, position_ids] = key_states\n        self.cache_v[:bsz, position_ids] = value_states\n        key_states, value_states = (\n                self.cache_k[:bsz, :, :],\n                self.cache_v[:bsz, :, :],\n            )\n\n        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=2)\n        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=2)\n\n        query_states, key_states, value_states = map(lambda x: x.transpose(1, 2), (query_states, key_states, value_states))\n        \n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=attention_mask,\n            dropout_p=0.0,\n            is_causal=False,\n        ).transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        
super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = LlamaAttention(config=config)\n        self.mlp = LlamaMLP(config)\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            positions_embedding = None,\n    ):\n        \"\"\"\n        Forward pass for the LlamaDecoderLayer.\n\n        Args:\n            hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`.\n            attention_mask (torch.FloatTensor, optional): Attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            position_ids (torch.LongTensor, optional): Positional IDs tensor.\n\n\n        Returns:\n            Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing:\n                - hidden_states (torch.FloatTensor): Output tensor.\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            positions_embedding=positions_embedding,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\nclass LlamaModel(PreTrainedModel):\n    \"\"\"\n    Transformer decoder 
consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = LlamaRotaryEmbedding(config=config)\n        self.max_position_embedding = config.max_position_embeddings\n        self.causal_mask = torch.tril(\n            torch.ones(self.max_position_embedding, self.max_position_embedding, dtype=torch.bool)\n        ).cuda()\n        self.post_init()\n        \n    def _create_attention_mask(self, input_pos: Optional[torch.Tensor]):\n        \"\"\"\n        Creates an attention mask for the transformer layers.\n\n        Args:\n            input_pos[torch.Tensor]: The position of input sequence (used for inference only).\n\n        Returns:\n            Optional[torch.Tensor]: The attention mask, or None for causal mask.\n        \"\"\"\n        mask = self.causal_mask[input_pos]\n        return mask\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n    ):\n\n        if input_ids is None:\n            raise ValueError(\n                \"decoder_input_ids is None\"\n            )\n        hidden_states = self.embed_tokens(input_ids)\n        \n        positions_embedding = self.rotary_emb(hidden_states, seq_len=self.max_position_embedding)\n\n        attention_mask = self._create_attention_mask(input_pos=position_ids)\n        
for idx, decoder_layer in enumerate(self.layers):\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                positions_embedding=positions_embedding,\n            )\n\n            hidden_states = layer_outputs\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\nclass LlamaForCausalLM(PreTrainedModel):\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = LlamaModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.post_init()\n\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            position_ids: Optional[torch.LongTensor] = None,\n    ):\n\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n        )\n        logits = self.lm_head(outputs)\n        return logits\n\n    def refresh_kvcache(self):\n        for i in self.model.layers:\n            i.self_attn.init_kv_cache()\n\n    def naive_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, top_p=None, top_k=None):\n\n        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)\n        if action_all is not None:\n            input_ids = torch.cat([input_ids, action_all[0]], dim=-1)\n        position_ids = torch.arange(0, input_ids.shape[1], device=\"cuda\")\n        next_token = self.prefill(\n            self,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            temperature=temperature,\n            top_k = top_k,\n            top_p = top_p,\n        )\n\n        self.decode_one_token = torch.compile(decode_one_token, mode=\"max-autotune\", fullgraph=True)\n        position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, 
device=\"cuda\")\n        \n        generated_tokens = decode_n_tokens(\n            self,\n            input_ids = next_token.view(1, -1),\n            position_ids = position_ids,\n            num_generate_tokens = max_new_tokens - 1,\n            temperature = temperature,\n            decode_one_token_function=self.decode_one_token,\n            action=action_all,\n            top_p = top_p,\n            top_k = top_k,\n        )\n        return torch.cat(generated_tokens, dim=1)\n    \n    def prefill_for_gradio(self, input_ids, temperature=1.0):\n        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)\n        last_pos = input_ids.shape[1]\n        position_ids = torch.arange(0, last_pos, device=\"cuda\")\n        with torch.no_grad(), torch.autocast(\"cuda\", dtype=torch.float16):\n            next_token = self.prefill(\n                self,\n                input_ids=input_ids,\n                position_ids=position_ids,\n                temperature=temperature,\n            )\n        return next_token, last_pos\n    \n    def decode_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0):\n        self.decode_one_token = torch.compile(decode_one_token, mode=\"max-autotune\", fullgraph=True)\n        # uncompiled fallback: self.decode_one_token = decode_one_token\n        position_ids = torch.arange(position_id, position_id + input_action.shape[1], device=\"cuda\")\n        with torch.no_grad(), torch.autocast(\"cuda\", dtype=torch.float16):\n            generated_tokens, position_id = decode_n_tokens_for_gradio(\n                self,\n                input_ids = input_action,\n                position_ids = position_ids,\n                num_generate_tokens = max_new_tokens,\n                temperature = temperature,\n                decode_one_token_function=self.decode_one_token,\n            )\n        return generated_tokens, position_id\n    \n    def 
diagd_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0, windowsize=2):\n        self.decode_some_token = torch.compile(decode_some_token, mode=\"max-autotune\", fullgraph=True)\n        position_ids = torch.arange(position_id, position_id + input_action.shape[1], device=\"cuda\")\n        with torch.no_grad(), torch.autocast(\"cuda\", dtype=torch.float16):\n            generated_tokens, position_id = img_diagd_decode_n_token_for_gradio(\n                self,\n                input_ids = input_action,\n                position_ids = position_ids,\n                num_generate_tokens = max_new_tokens,\n                temperature = temperature,\n                decode_some_token_function=self.decode_some_token,\n                windowsize = windowsize,\n            )\n        return generated_tokens, position_id\n\n\n    def img_diagd_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, windowsize=2, top_p=None, top_k=None):\n\n        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)\n        input_ids = torch.cat([input_ids, action_all[0]], dim=-1)\n        position_ids = torch.arange(0, input_ids.shape[1], device=\"cuda\")\n        next_token = self.prefill(\n            self,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            temperature=temperature,\n            top_k = top_k,\n            top_p = top_p,\n        )\n\n        self.decode_some_token = torch.compile(decode_some_token, mode=\"max-autotune\", fullgraph=True)\n        position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device=\"cuda\")\n\n        generated_tokens = img_diagd_decode_n_tokens(\n            self,\n            input_ids = next_token.view(1, -1),\n            position_ids = position_ids,\n            num_generate_tokens = max_new_tokens - 1,\n            temperature = temperature,\n            decode_some_token_function=self.decode_some_token,\n            
windowsize = windowsize,\n            action=action_all,\n            prompt=input_ids,\n            top_k = top_k,\n            top_p = top_p,\n        )\n        return torch.cat(generated_tokens, dim=1)\n    \n    def vid_diagd_generate(self, input_ids, max_new_tokens, windowsize=2, temperature=1.0, action_all=None, **kwargs):\n\n        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)\n        input_ids = torch.cat([input_ids, action_all[0]], dim=-1)\n        position_ids = torch.arange(0, input_ids.shape[1], device=\"cuda\")\n        next_token = self.prefill(\n            self,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            temperature=temperature,\n        )\n\n        self.decode_some_token = torch.compile(decode_some_token, mode=\"max-autotune\", fullgraph=True)\n        # uncompiled fallback: self.decode_some_token = decode_some_token\n        position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device=\"cuda\")\n\n        generated_tokens = video_diagd_decode_n_tokens(\n            self,\n            input_ids = next_token.view(1, -1),\n            position_ids = position_ids,\n            num_generate_tokens = max_new_tokens - 1,\n            temperature = temperature,\n            decode_some_token_function=self.decode_some_token,\n            windowsize = windowsize,\n            action=action_all,\n            prompt=input_ids,\n            **kwargs\n        )\n        return torch.cat(generated_tokens, dim=1)\n\n\n"
  },
  {
    "path": "mcdataset.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport os\nimport json\nimport attr\nimport collections\nimport numpy as np\nfrom typing import Union, Dict\nimport torch\nfrom utils import print0\n\n\n# https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L17\nKEYBOARD_BUTTON_MAPPING = {\n    \"key.keyboard.escape\" :\"ESC\",\n    \"key.keyboard.s\" :\"back\",\n    \"key.keyboard.q\" :\"drop\",\n    \"key.keyboard.w\" :\"forward\",\n    \"key.keyboard.1\" :\"hotbar.1\",\n    \"key.keyboard.2\" :\"hotbar.2\",\n    \"key.keyboard.3\" :\"hotbar.3\",\n    \"key.keyboard.4\" :\"hotbar.4\",\n    \"key.keyboard.5\" :\"hotbar.5\",\n    \"key.keyboard.6\" :\"hotbar.6\",\n    \"key.keyboard.7\" :\"hotbar.7\",\n    \"key.keyboard.8\" :\"hotbar.8\",\n    \"key.keyboard.9\" :\"hotbar.9\",\n    \"key.keyboard.e\" :\"inventory\",\n    \"key.keyboard.space\" :\"jump\",\n    \"key.keyboard.a\" :\"left\",\n    \"key.keyboard.d\" :\"right\",\n    \"key.keyboard.left.shift\" :\"sneak\",\n    \"key.keyboard.left.control\" :\"sprint\",\n    \"key.keyboard.f\" :\"swapHands\",\n}\n\n# https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L41\n# Template action\nNOOP_ACTION = {\n    \"ESC\": 0,\n    \"back\": 0,\n    \"drop\": 0,\n    \"forward\": 0,\n    \"hotbar.1\": 0,\n    \"hotbar.2\": 0,\n    \"hotbar.3\": 0,\n    \"hotbar.4\": 0,\n    \"hotbar.5\": 0,\n    \"hotbar.6\": 0,\n    \"hotbar.7\": 0,\n    \"hotbar.8\": 0,\n    \"hotbar.9\": 0,\n    \"inventory\": 0,\n    \"jump\": 0,\n    \"left\": 0,\n    \"right\": 0,\n    \"sneak\": 0,\n    \"sprint\": 0,\n    \"swapHands\": 0,\n    \"camera\": np.array([0, 0]),  # [y, x]\n    \"attack\": 0,\n    \"use\": 0,\n    \"pickItem\": 0,\n}\n\nOASIS_ACTION_KEYS = [\n    \"inventory\",\n    \"ESC\",\n    \"hotbar.1\",\n    \"hotbar.2\",\n    \"hotbar.3\",\n    \"hotbar.4\",\n    
\"hotbar.5\",\n    \"hotbar.6\",\n    \"hotbar.7\",\n    \"hotbar.8\",\n    \"hotbar.9\",\n    \"forward\",\n    \"back\",\n    \"left\",\n    \"right\",\n    \"cameraX\",\n    \"cameraY\",\n    \"jump\",\n    \"sneak\",\n    \"sprint\",\n    \"swapHands\",\n    \"attack\",\n    \"use\",\n    \"pickItem\",\n    \"drop\",\n]\n\n\n# Matches a number in the MineRL Java code regarding sensitivity\n# This is for mapping from recorded sensitivity to the one used in the model\nCAMERA_SCALER = 360.0 / 2400.0\n\n\n# https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L8 with some modifications\nclass Buttons:\n    # 14 in total without hotbar and camera\n    ATTACK = \"attack\"\n    BACK = \"back\"\n    FORWARD = \"forward\"\n    JUMP = \"jump\"\n    LEFT = \"left\"\n    RIGHT = \"right\"\n    SNEAK = \"sneak\"\n    SPRINT = \"sprint\"\n    USE = \"use\"\n    DROP = \"drop\"\n    INVENTORY = \"inventory\"\n    # added by Yang\n    ESC = \"ESC\"\n    SWAPHANDS = \"swapHands\"\n    PICKITEM = \"pickItem\"\n\n    ALL = [\n        USE,\n        ATTACK,\n\n        \n        FORWARD,\n        BACK,\n        LEFT,\n        RIGHT,\n\n        JUMP,\n        SNEAK,\n        SPRINT,\n        \n        DROP,\n        SWAPHANDS,\n        PICKITEM,\n\n        INVENTORY,\n        ESC,\n    ] + [f\"hotbar.{i}\" for i in range(1, 10)]\n\n\nclass QuantizationScheme:\n    LINEAR = \"linear\"\n    MU_LAW = \"mu_law\"\n\n\n# https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L49\n@attr.s(auto_attribs=True)\nclass CameraQuantizer:\n    \"\"\"\n    A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components.\n\n    Parameters:\n    - camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize.\n    - camera_maxval: The maximum value of the camera action.\n    - quantization_scheme: The quantization scheme to use. 
Currently, two quantization schemes are supported:\n    - Linear quantization (default): Camera actions are split uniformly into discrete bins\n    - Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm)\n    followed by the same quantization scheme used by the linear scheme.\n    - mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of\n    mu will result in a sharper transition near zero. Below are some reference values listed\n    for choosing mu given a constant maxval and a desired max_precision value.\n    maxval = 10 | max_precision = 0.5  | μ ≈ 2.93826\n    maxval = 10 | max_precision = 0.4  | μ ≈ 4.80939\n    maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887\n    maxval = 20 | max_precision = 0.5  | μ ≈ 2.7\n    maxval = 20 | max_precision = 0.4  | μ ≈ 4.39768\n    maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194\n    maxval = 40 | max_precision = 0.5  | μ ≈ 2.60780\n    maxval = 40 | max_precision = 0.4  | μ ≈ 4.21554\n    maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152\n    \"\"\"\n\n    camera_maxval: int\n    camera_binsize: int\n    quantization_scheme: str = attr.ib(\n        default=QuantizationScheme.LINEAR,\n        validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),\n    )\n    mu: float = attr.ib(default=5)\n\n    def discretize(self, xy):\n        xy = np.clip(xy, -self.camera_maxval, self.camera_maxval)\n\n        if self.quantization_scheme == QuantizationScheme.MU_LAW:\n            xy = xy / self.camera_maxval\n            v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu))\n            v_encode *= self.camera_maxval\n            xy = v_encode\n\n        # Quantize using linear scheme\n        return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64)\n\n    def undiscretize(self, xy):\n        xy = xy * self.camera_binsize - 
self.camera_maxval\n\n        if self.quantization_scheme == QuantizationScheme.MU_LAW:\n            xy = xy / self.camera_maxval\n            v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0)\n            v_decode *= self.camera_maxval\n            xy = v_decode\n        return xy\n\n\nclass MCDataset(torch.utils.data.Dataset):\n    \"\"\"\n    Dataset for Minecraft.\n    \"\"\"\n    def __init__(self,\n                 action_length: int = 11,  # including bos and eos\n                 camera_binsize: int = 9,  # 2 in vpt\n                 camera_maxval: int = 90,  # 10 in vpt\n                 camera_mu: float = 11.4887,  # 10 in vpt\n                 quantization_scheme: str = \"mu_law\",\n    ):\n        self.action_length = action_length\n        self.camera_quantizer = CameraQuantizer(\n            camera_binsize=camera_binsize,\n            camera_maxval=camera_maxval,\n            mu=camera_mu,\n            quantization_scheme=quantization_scheme,\n        )\n\n    def json_action_to_env_action(self, json_action):\n        \"\"\"\n        https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L80\n        Converts a json action into a MineRL action.\n        Returns (minerl_action, is_null_action)\n        \"\"\"\n        # This might be slow...\n        env_action = NOOP_ACTION.copy()\n        # As a safeguard, make camera action again so we do not override anything\n        env_action[\"camera\"] = np.array([0, 0])\n\n        is_null_action = True\n        keyboard_keys = json_action[\"keyboard\"][\"keys\"]\n        for key in keyboard_keys:\n            # You can have keys that we do not use, so just skip them\n            # NOTE in original training code, ESC was removed and replaced with\n            #      \"inventory\" action if GUI was open.\n            #      Not doing it here, as BASALT uses ESC to quit the game.\n            if key in 
KEYBOARD_BUTTON_MAPPING:\n                env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1\n                is_null_action = False\n\n        mouse = json_action[\"mouse\"]\n        camera_action = env_action[\"camera\"]\n        camera_action[0] = mouse[\"dy\"] * CAMERA_SCALER\n        camera_action[1] = mouse[\"dx\"] * CAMERA_SCALER\n\n        if mouse[\"dx\"] != 0 or mouse[\"dy\"] != 0:\n            is_null_action = False\n        else:\n            if abs(camera_action[0]) > 180:\n                camera_action[0] = 0\n            if abs(camera_action[1]) > 180:\n                camera_action[1] = 0\n\n        mouse_buttons = mouse[\"buttons\"]\n        if 0 in mouse_buttons:\n            env_action[\"attack\"] = 1\n            is_null_action = False\n        if 1 in mouse_buttons:\n            env_action[\"use\"] = 1\n            is_null_action = False\n        if 2 in mouse_buttons:\n            env_action[\"pickItem\"] = 1\n            is_null_action = False\n\n        # added by Yang\n        # if two conflicting actions are pressed, remove them both\n        if env_action[\"forward\"] == 1 and env_action[\"back\"] == 1:\n            env_action[\"forward\"] = 0\n            env_action[\"back\"] = 0\n        if env_action[\"left\"] == 1 and env_action[\"right\"] == 1:\n            env_action[\"left\"] = 0\n            env_action[\"right\"] = 0\n        if env_action[\"jump\"] == 1 and env_action[\"sneak\"] == 1:\n            env_action[\"jump\"] = 0\n            env_action[\"sneak\"] = 0\n        if env_action[\"sprint\"] == 1 and env_action[\"sneak\"] == 1:\n            env_action[\"sprint\"] = 0\n            env_action[\"sneak\"] = 0\n        if env_action[\"attack\"] == 1 and env_action[\"use\"] == 1:\n            env_action[\"attack\"] = 0\n            env_action[\"use\"] = 0\n\n        # remove inventory and ESC action\n        if env_action[\"inventory\"] == 1:\n            is_null_action = True\n        if env_action[\"ESC\"] == 1:\n            is_null_action = 
True\n\n        return env_action, is_null_action\n\n    def make_action_vocab(self,\n                          num_cam_bins: int = 21,\n                          action_vocab_offset: int = 0,\n                          verbose: bool = False):\n        action_vocab = collections.OrderedDict()\n        # 14 actions and hotbar.1-9\n        for i, action in enumerate(Buttons.ALL):\n            action_vocab[action] = i\n        # camera 0 \n        for i in range(num_cam_bins):\n            action_vocab[f\"cam_0_{i}\"] = len(Buttons.ALL) + i\n        # camera 1\n        for i in range(num_cam_bins):\n            action_vocab[f\"cam_1_{i}\"] = len(Buttons.ALL) + num_cam_bins + i\n        # bos, null, eos\n        action_vocab[\"<act_bos>\"] = len(Buttons.ALL) + 2 * num_cam_bins\n        action_vocab[\"<null_act>\"] = len(Buttons.ALL) + 2 * num_cam_bins + 1\n        action_vocab[\"<act_eos>\"] = len(Buttons.ALL) + 2 * num_cam_bins + 2\n\n        if action_vocab_offset > 0:\n            action_vocab = {k: v + action_vocab_offset for k, v in action_vocab.items()}\n\n        if verbose:\n            print0(f\"[bold yellow]\\[MCDataset][/bold yellow] Action Vocab: {action_vocab}\")\n\n        self.action_vocab = action_vocab\n        # return action_vocab\n\n    def _handle_conflict_action_index(self,\n                                action_dict: Dict[str, Union[int, np.ndarray]],\n                                key1: str,\n                                key2: str,\n                                null_key: str,\n                                verbose: bool = False):\n        if action_dict[key1] == 1 and action_dict[key2] == 1:\n            if verbose:\n                print0(f\"[bold yellow]\\[MCDataset][/bold yellow] {key1} and {key2} are both pressed\")\n            return self.action_vocab[null_key]\n        elif action_dict[key1] == 1:\n            return self.action_vocab[key1]\n        elif action_dict[key2] == 1:\n            return self.action_vocab[key2]\n      
  else:\n            return self.action_vocab[null_key]\n\n    def get_action_index_from_actiondict(self,\n                                         action_dict: Dict[str, Union[int, np.ndarray]],\n                                         action_vocab_offset: int = 0,\n                                         verbose: bool = False):\n\n        if not hasattr(self, \"action_vocab\"):\n            self.make_action_vocab(action_vocab_offset=action_vocab_offset, verbose=verbose)\n\n        # action_list = [boa, camy, camx, hotbar, fore_back, left_right, sprint_sneak, use_attack, jump, drop_pick, eoa]\n        # 11 actions\n        action_list = [self.action_vocab[\"<null_act>\"]] * self.action_length\n        # 0 & 10\n        action_list[0] = self.action_vocab[\"<act_bos>\"]\n        action_list[-1] = self.action_vocab[\"<act_eos>\"]\n\n        camera_action = action_dict[\"camera\"]\n        assert len(camera_action) == 2, f\"[MCDataset] camera_action length is not 2: {camera_action}\"\n        # camera_action should be numpy array\n        if not isinstance(camera_action, np.ndarray):\n            camera_action = np.array(camera_action)\n        camera_action = self.camera_quantizer.discretize(camera_action)\n        # 1 & 2\n        action_list[1] = self.action_vocab[f\"cam_0_{camera_action[0]}\"]\n        action_list[2] = self.action_vocab[f\"cam_1_{camera_action[1]}\"]\n\n        # 3\n        for i in range(1, 10):\n            if f\"hotbar.{i}\" in action_dict and action_dict[f\"hotbar.{i}\"] == 1:\n                action_list[3] = self.action_vocab[f\"hotbar.{i}\"]\n                break\n\n        # 4 forward/back\n        action_list[4] = self._handle_conflict_action_index(action_dict, \"forward\", \"back\", \"<null_act>\", verbose=verbose)\n        # 5 left/right\n        action_list[5] = self._handle_conflict_action_index(action_dict, \"left\", \"right\", \"<null_act>\", verbose=verbose)\n        # 6 sprint/sneak\n        action_list[6] = 
self._handle_conflict_action_index(action_dict, \"sprint\", \"sneak\", \"<null_act>\", verbose=verbose)\n        # 7 use/attack\n        action_list[7] = self._handle_conflict_action_index(action_dict, \"use\", \"attack\", \"<null_act>\", verbose=verbose)\n        # 8 jump\n        action_list[8] = self.action_vocab[\"jump\"] if action_dict[\"jump\"] == 1 else self.action_vocab[\"<null_act>\"]\n        # 9 drop/pick\n        action_list[9] = self._handle_conflict_action_index(action_dict, \"drop\", \"pickItem\", \"<null_act>\", verbose=verbose)\n\n        if verbose:\n            print0(f\"[bold yellow]\\[MCDataset][/bold yellow] Action List: {action_list}\")\n\n        return action_list\n\n    def read_jsonl(self, jsonl_path: str):\n        assert os.path.isfile(jsonl_path), f\"[MCDataset] {jsonl_path} does not exist\"\n        # read jsonl\n        # https://github.com/openai/Video-Pre-Training/blob/main/data_loader.py#L76\n        try:\n            with open(jsonl_path) as json_file:\n                json_lines = json_file.readlines()\n                json_data = \"[\" + \",\".join(json_lines) + \"]\"\n                json_data = json.loads(json_data)\n        except Exception as e:\n            print0(f\"[bold yellow]\\[MCDataset][/bold yellow] {jsonl_path} cannot be read: {e}\")\n            return None\n        return json_data"
  },
  {
    "path": "metrics/IDM/inverse_dynamics_model.py",
    "content": "# Borrowed from VPT (https://github.com/openai/Video-Pre-Training)\n\nimport numpy as np\nimport torch as th\nimport cv2\nfrom gym3.types import DictType\nfrom gym import spaces\nfrom tqdm import tqdm\nimport os\nfrom argparse import ArgumentParser\nimport pickle\nimport json\nfrom lib.action_mapping import IDMActionMapping\nfrom lib.actions import ActionTransformer\nfrom lib.policy import InverseActionPolicy\nfrom lib.torch_util import default_device_type, set_default_torch_device\n\nfrom sklearn.metrics import precision_score, recall_score, f1_score\n\n# Hardcoded settings\nAGENT_RESOLUTION = (128, 128)\nsafe_globals = {\"array\": np.array}\n\n\ndef resize_image(img, target_resolution):\n    # For your sanity, do not resize with any function other than INTER_LINEAR\n    img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)\n    return img\n\n\nACTION_TRANSFORMER_KWARGS = dict(\n    camera_binsize=2,\n    camera_maxval=10,\n    camera_mu=10,\n    camera_quantization_scheme=\"mu_law\",\n)\n\nclass IDMAgent:\n    \"\"\"\n    Sugarcoating on the inverse dynamics model (IDM) used to predict actions Minecraft players take in videos.\n\n    Functionally same as MineRLAgent.\n    \"\"\"\n    def __init__(self, idm_net_kwargs, pi_head_kwargs, device=None):\n        if device is None:\n            device = default_device_type()\n        self.device = th.device(device)\n        # Set the default torch device for underlying code as well\n        set_default_torch_device(self.device)\n        self.action_mapper = IDMActionMapping(n_camera_bins=11)\n        action_space = self.action_mapper.get_action_space_update()\n        action_space = DictType(**action_space)\n\n        self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)\n\n        idm_policy_kwargs = dict(idm_net_kwargs=idm_net_kwargs, pi_head_kwargs=pi_head_kwargs, action_space=action_space)\n\n        self.policy = 
InverseActionPolicy(**idm_policy_kwargs).to(device)\n        self.hidden_state = self.policy.initial_state(1)\n        self._dummy_first = th.from_numpy(np.array((False,))).to(device)\n\n    def load_weights(self, path):\n        \"\"\"Load model weights from a path, and reset hidden state\"\"\"\n        self.policy.load_state_dict(th.load(path, map_location=self.device), strict=False)\n        self.reset()\n\n    def reset(self):\n        \"\"\"Reset agent to initial state (i.e., reset hidden state)\"\"\"\n        self.hidden_state = self.policy.initial_state(1)\n\n    def _video_obs_to_agent(self, video_frames):\n        imgs = [resize_image(frame, AGENT_RESOLUTION) for frame in video_frames]\n        # Add time and batch dim\n        imgs = np.stack(imgs)[None]\n        agent_input = {\"img\": th.from_numpy(imgs).to(self.device)}\n        return agent_input\n\n    def _agent_action_to_env(self, agent_action):\n        \"\"\"Turn output from policy into action for MineRL\"\"\"\n        # This is quite important step (for some reason).\n        # For the sake of your sanity, remember to do this step (manual conversion to numpy)\n        # before proceeding. Otherwise, your agent might be a little derp.\n        action = {\n            \"buttons\": agent_action[\"buttons\"].cpu().numpy(),\n            \"camera\": agent_action[\"camera\"].cpu().numpy()\n        }\n        minerl_action = self.action_mapper.to_factored(action)\n        minerl_action_transformed = self.action_transformer.policy2env(minerl_action)\n        return minerl_action_transformed\n\n    def predict_actions(self, video_frames):\n        \"\"\"\n        Predict actions for a sequence of frames.\n\n        `video_frames` should be of shape (N, H, W, C).\n        Returns MineRL action dict, where each action head\n        has shape (N, ...).\n\n        Agent's hidden state is tracked internally. 
To reset it,\n        call `reset()`.\n        \"\"\"\n        agent_input = self._video_obs_to_agent(video_frames)\n        # The \"first\" argument could be used to tell episode\n        # boundaries, but we are only using this for predicting (for now),\n        # so we do not hassle with it yet.\n        dummy_first = th.zeros((video_frames.shape[0], 1)).to(self.device)\n        predicted_actions, self.hidden_state, _ = self.policy.predict(\n            agent_input, first=dummy_first, state_in=self.hidden_state,\n            deterministic=True\n        )\n        predicted_minerl_action = self._agent_action_to_env(predicted_actions)\n        return predicted_minerl_action\n\n\n# NOTE: this is _not_ the original code of IDM!\n# As such, while it is close and seems to function well,\n# its performance might be a bit off from what is reported\n# in the paper.\n\nENV_KWARGS = dict(\n    fov_range=[70, 70],\n    frameskip=1,\n    gamma_range=[2, 2],\n    guiscale_range=[1, 1],\n    resolution=[640, 360],\n    cursor_size_range=[16.0, 16.0],\n)\n\n\nKEYBOARD_BUTTON_MAPPING = {\n    \"key.keyboard.escape\" :\"ESC\",\n    \"key.keyboard.s\" :\"back\",\n    \"key.keyboard.q\" :\"drop\",\n    \"key.keyboard.w\" :\"forward\",\n    \"key.keyboard.1\" :\"hotbar.1\",\n    \"key.keyboard.2\" :\"hotbar.2\",\n    \"key.keyboard.3\" :\"hotbar.3\",\n    \"key.keyboard.4\" :\"hotbar.4\",\n    \"key.keyboard.5\" :\"hotbar.5\",\n    \"key.keyboard.6\" :\"hotbar.6\",\n    \"key.keyboard.7\" :\"hotbar.7\",\n    \"key.keyboard.8\" :\"hotbar.8\",\n    \"key.keyboard.9\" :\"hotbar.9\",\n    \"key.keyboard.e\" :\"inventory\",\n    \"key.keyboard.space\" :\"jump\",\n    \"key.keyboard.a\" :\"left\",\n    \"key.keyboard.d\" :\"right\",\n    \"key.keyboard.left.shift\" :\"sneak\",\n    \"key.keyboard.left.control\" :\"sprint\",\n    \"key.keyboard.f\" :\"swapHands\",\n}\n\n# Template action\nNOOP_ACTION = {\n    \"ESC\": 0,\n    \"back\": 0,\n    \"drop\": 0,\n    \"forward\": 0,\n    
\"hotbar.1\": 0,\n    \"hotbar.2\": 0,\n    \"hotbar.3\": 0,\n    \"hotbar.4\": 0,\n    \"hotbar.5\": 0,\n    \"hotbar.6\": 0,\n    \"hotbar.7\": 0,\n    \"hotbar.8\": 0,\n    \"hotbar.9\": 0,\n    \"inventory\": 0,\n    \"jump\": 0,\n    \"left\": 0,\n    \"right\": 0,\n    \"sneak\": 0,\n    \"sprint\": 0,\n    \"swapHands\": 0,\n    \"camera\": np.array([0, 0]),\n    \"attack\": 0,\n    \"use\": 0,\n    \"pickItem\": 0,\n}\n\n\n# Matches a number in the MineRL Java code regarding sensitivity\n# This is for mapping from recorded sensitivity to the one used in the model\nCAMERA_SCALER = 360.0 / 2400.0\n\n\ndef json_action_to_env_action(json_action):\n    \"\"\"\n    Converts a json action into a MineRL action.\n    Returns (minerl_action, is_null_action)\n    \"\"\"\n    if \"ESC\" in json_action:\n        return json_action, False\n    # This might be slow...\n    env_action = NOOP_ACTION.copy()\n    # As a safeguard, make camera action again so we do not override anything\n    env_action[\"camera\"] = np.array([0, 0])\n\n    is_null_action = True\n    keyboard_keys = json_action[\"keyboard\"][\"keys\"]\n    for key in keyboard_keys:\n        # You can have keys that we do not use, so just skip them\n        # NOTE in original training code, ESC was removed and replaced with\n        #      \"inventory\" action if GUI was open.\n        #      Not doing it here, as BASALT uses ESC to quit the game.\n        if key in KEYBOARD_BUTTON_MAPPING:\n            env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1\n            is_null_action = False\n\n    mouse = json_action[\"mouse\"]\n    camera_action = env_action[\"camera\"]\n    camera_action[0] = mouse[\"dy\"] * CAMERA_SCALER\n    camera_action[1] = mouse[\"dx\"] * CAMERA_SCALER\n\n    if mouse[\"dx\"] != 0 or mouse[\"dy\"] != 0:\n        is_null_action = False\n    else:\n        if abs(camera_action[0]) > 180:\n            camera_action[0] = 0\n        if abs(camera_action[1]) > 180:\n            camera_action[1] = 
0\n\n    mouse_buttons = mouse[\"buttons\"]\n    if 0 in mouse_buttons:\n        env_action[\"attack\"] = 1\n        is_null_action = False\n    if 1 in mouse_buttons:\n        env_action[\"use\"] = 1\n        is_null_action = False\n    if 2 in mouse_buttons:\n        env_action[\"pickItem\"] = 1\n        is_null_action = False\n\n    return env_action, is_null_action\n\n\ndef load_action_jsonl(json_path):\n    with open(json_path) as json_file:\n        json_lines = json_file.readlines()\n        json_data = \"[\" + \",\".join(json_lines) + \"]\"\n        json_data = json.loads(json_data)\n    return json_data\n\n\n# loss on frame - avg on video - avg on dataset\ndef evaluate_IDM_quality(model, weights, jsonl_folder, video_folder, infer_demo_num, n_frames, output_file):\n    \"\"\"\n    Evaluate the quality of an IDM model on a dataset of videos.\n\n    Args:\n        model (str): Path to the '.model' file to be loaded.\n        weights (str): Path to the '.weights' file to be loaded.\n        jsonl_folder (str): Path to the folder containing the recorded-action '.jsonl' files.\n        video_folder (str): Path to the folder containing videos.\n        infer_demo_num (int): Number of demos to run inference on.\n        n_frames (int): Number of frames to process at a time.\n        output_file (str): Path of the JSON file the metrics are written to.\n    \"\"\"\n    ## set up IDM model\n    agent_parameters = pickle.load(open(model, \"rb\"))\n    net_kwargs = agent_parameters[\"model\"][\"args\"][\"net\"][\"args\"]\n    pi_head_kwargs = agent_parameters[\"model\"][\"args\"][\"pi_head_opts\"]\n    pi_head_kwargs[\"temperature\"] = float(pi_head_kwargs[\"temperature\"])\n    agent = IDMAgent(idm_net_kwargs=net_kwargs, pi_head_kwargs=pi_head_kwargs)\n    agent.load_weights(weights)\n    # Load video files\n    video_files = os.listdir(video_folder)\n    video_files = [f for f in video_files if f.endswith(\".mp4\")]\n    video_files = sorted(video_files)\n    video_files = [os.path.join(video_folder, f) for f in video_files]\n    eval_num = min(500, len(video_files))\n    video_files = 
video_files[:eval_num]\n    dataset_labels = {}\n    camera_loss_list = []\n    for video_file in tqdm(video_files):\n        json_file = os.path.join(jsonl_folder, os.path.basename(video_file).replace(\".mp4\", \".jsonl\"))\n        # load predicted actions and recorded actions\n        predicted_actions, recorded_actions = idm_prediction(agent, video_file, json_file, infer_demo_num, n_frames)\n        # construct labels\n        subtasks_labels = define_exclusive_classification_task(predicted_actions, recorded_actions, calculate_hot_bar=False)\n        for key in subtasks_labels:\n            if key not in dataset_labels:\n                dataset_labels[key] = {\"pred_labels\": [], \"rec_labels\": [], \"class_num\": 0}\n            dataset_labels[key][\"pred_labels\"].append(subtasks_labels[key][\"pred_labels\"])  # array\n            dataset_labels[key][\"rec_labels\"].append(subtasks_labels[key][\"rec_labels\"])  # array\n            dataset_labels[key][\"class_num\"] = subtasks_labels[key][\"class_num\"]\n        camera_loss_list.append(camera_loss(predicted_actions, recorded_actions)[\"camera_bin_loss\"])\n\n    dataset_results = {}\n    for key in dataset_labels:\n        pred_labels = np.stack(dataset_labels[key][\"pred_labels\"]).flatten()  # [num_videos, num_frames] -> [num_videos x num_frames]\n        rec_labels = np.stack(dataset_labels[key][\"rec_labels\"]).flatten()  # [num_videos, num_frames] -> [num_videos x num_frames]\n        dataset_results[key] = classification_metric(pred_labels, rec_labels, dataset_labels[key][\"class_num\"])\n\n    metric_mean_on_task = {}\n    metrics = ['precision_micro', 'recall_micro', 'f1_micro', 'precision_macro', 'recall_macro', 'f1_macro']\n    tasks = dataset_results.keys()\n    for key in metrics:\n        if 
key == \"class_num\":\n            continue\n        metric_mean_on_task[key] = np.mean([dataset_results[task][key] for task in tasks])\n    dataset_results[\"metric_mean_on_task\"] = metric_mean_on_task\n    dataset_results[\"metric_mean_on_task\"][\"camera_loss\"] = np.mean(camera_loss_list)\n    ## turn all keys into strings so the results are JSON-serializable\n    dataset_results = {str(k): v for k, v in dataset_results.items()}\n\n    print(dataset_results)\n    print(\"===========================================\")\n    print(f\"{output_file} IDM Metric: {metric_mean_on_task}\")\n    with open(output_file, 'w') as f:\n        f.write(json.dumps(dataset_results, indent=4) + \"\\n\")\n\ndef construct_classification_labels(idm_actions: dict[str, np.ndarray], action_name_keys: list[str], num_class: int) -> np.ndarray | None:\n    \"\"\"\n    Convert raw predicted actions into mutually exclusive classification labels.\n    \"\"\"\n    # map each one-hot button-vector string to an integer class label\n    vec2cls = {\"0\"*(num_class-1): 0}\n    for i in range(num_class-1):\n        key = \"0\"*i + \"1\" + \"0\"*(num_class-2-i)\n        vec2cls[key] = i+1\n    vec2cls['1'*(num_class-1)] = 0  # pressing all buttons in the group is treated as pressing none\n    # e.g. two buttons (num_class = 3): vec2cls = {\"00\": 0, \"10\": 1, \"01\": 2}\n    num_labels = idm_actions[action_name_keys[0]].size  # assumes equal length: video_num x frame_per_video\n    # if the first dim is not singleton, flatten before building labels\n\n    # construct one-hot vector strings\n    idm_action_string = [[str(int(i)) for i in idm_actions[action_name].flatten()] for action_name in action_name_keys]\n    try:\n        labels = [vec2cls[\"\".join([idm_action_string[j][i] for j in range(num_class-1)])] for i in range(num_labels)]\n    except KeyError:\n        conflicts_num = sum([i == '1' and j == '1' for i, j in zip(idm_action_string[0], idm_action_string[1])])\n        print(f\"detected conflicting predictions: {conflicts_num}\")\n        return None\n\n    labels = np.array(labels)\n    return labels\n\ndef 
define_exclusive_classification_task(predicted_actions: dict, recorded_actions: dict, calculate_hot_bar=False) -> dict:\n    subtasks = {\"multi_class\": [(\"back\", \"forward\"),  # labels: 00, 10, 01\n                                (\"left\", \"right\"),\n                                (\"sneak\", \"sprint\"),\n                                ],\n                \"binary_class\": [\"use\", \"attack\", \"jump\", \"drop\"]\n    }\n    if calculate_hot_bar:\n        subtasks[\"multi_class\"] = [(\"hotbar.1\", \"hotbar.2\", \"hotbar.3\", \"hotbar.4\", \"hotbar.5\", \"hotbar.6\", \"hotbar.7\", \"hotbar.8\", \"hotbar.9\")]\n    subtasks_labels = {}\n    for class_pair in subtasks[\"multi_class\"]:\n        class_num = len(class_pair) + 1  # e.g. two buttons give classes 00, 01, 10\n        # convert each action group to mutually exclusive classification labels\n        pred_labels = construct_classification_labels(predicted_actions, class_pair, class_num)\n        rec_labels = construct_classification_labels(recorded_actions, class_pair, class_num)\n        if pred_labels is None or rec_labels is None:\n            print(f\"skipping task {class_pair}: conflicting button predictions\")\n            continue\n        subtasks_labels[class_pair] = {\"class_num\": class_num,\n                                       \"pred_labels\": pred_labels,\n                                       \"rec_labels\": rec_labels\n                                       }\n    for binary_task in subtasks[\"binary_class\"]:\n        pred_labels = predicted_actions[binary_task]\n        rec_labels = recorded_actions[binary_task]\n        subtasks_labels[binary_task] = {\"class_num\": 2,\n                                       \"pred_labels\": pred_labels,\n                                       \"rec_labels\": rec_labels\n                                       }\n    return subtasks_labels\n\ndef classification_metric(pred_labels, rec_labels, class_num):\n    ## compute macro and micro scores for both the multi-class and binary classification tasks\n    ## 
the difference between micro/macro averaging and plain binary precision on a binary task:\n    ## binary precision only scores the positive label (1), while micro and macro averaging score both 0 and 1 and then average them\n    ## to stay consistent with the multi-class tasks we use average=\"macro\" and average=\"micro\"\n    precision_micro = precision_score(rec_labels, pred_labels, average=\"micro\", zero_division=0)\n    recall_micro = recall_score(rec_labels, pred_labels, average=\"micro\", zero_division=0)\n    f1_micro = f1_score(rec_labels, pred_labels, average=\"micro\", zero_division=0)\n\n    precision_macro = precision_score(rec_labels, pred_labels, average=\"macro\", zero_division=0)\n    recall_macro = recall_score(rec_labels, pred_labels, average=\"macro\", zero_division=0)\n    f1_macro = f1_score(rec_labels, pred_labels, average=\"macro\", zero_division=0)\n    return {\n        \"precision_micro\": precision_micro,\n        \"recall_micro\": recall_micro,\n        \"f1_micro\": f1_micro,\n        \"precision_macro\": precision_macro,\n        \"recall_macro\": recall_macro,\n        \"f1_macro\": f1_macro,\n        \"class_num\": class_num\n    }\n\ndef aggregate_actions(actions: list) -> dict:\n    \"\"\"Stack a list of per-frame action dicts into one dict of flat arrays.\"\"\"\n    return_dict = {}\n    for action in actions:\n        for key in action:\n            if key not in return_dict:\n                return_dict[key] = []\n            return_dict[key].append(action[key])\n    for key in return_dict:\n        return_dict[key] = np.array(return_dict[key]).reshape(-1)\n    return return_dict\n\ndef idm_prediction(agent, video_path, json_path, infer_demo_num, n_frames):\n    th.cuda.empty_cache()\n    full_json_data = load_action_jsonl(json_path)\n    json_data = full_json_data[infer_demo_num:infer_demo_num+n_frames]\n    recorded_actions = [json_action_to_env_action(i)[0] for i in json_data]\n    recorded_actions = aggregate_actions(recorded_actions)\n    frames = []\n    cap = cv2.VideoCapture(video_path)\n    for _ in range(n_frames):\n        ret, 
frame = cap.read()\n        if not ret:\n            print(f\"[Error] failed to load frames from {video_path}: stopped at frame {_}\")\n            return None, None\n        # BGR -> RGB\n        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n        frames.append(frame)\n    frames = np.stack(frames)\n    predicted_actions = agent.predict_actions(frames)\n    for key in predicted_actions:\n        if key == \"camera\":\n            continue\n        predicted_actions[key] = np.array(predicted_actions[key]).reshape(-1)\n    return predicted_actions, recorded_actions\n\ndef camera_loss(predicted_actions, recorded_actions):\n    from lib.actions import CameraQuantizer\n    cam_quantizer = CameraQuantizer(\n        camera_binsize=2,\n        camera_maxval=10,\n        mu=10,\n        quantization_scheme=\"mu_law\")\n    cam_pred_token = cam_quantizer.discretize(predicted_actions['camera'].reshape(-1))\n    cam_gt_token = cam_quantizer.discretize(np.array(recorded_actions['camera']))\n    camera_bin_loss = np.abs(cam_pred_token - cam_gt_token).mean()\n    return {\n        \"camera_bin_loss\": camera_bin_loss\n    }\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(\"Evaluate IDM quality for MC-LVM\")\n    parser.add_argument(\"--weights\", type=str, required=True, help=\"[IDM model config] Path to the '.weights' file to be loaded.\")\n    parser.add_argument(\"--model\", type=str, required=True, help=\"[IDM model config] Path to the '.model' file to be loaded.\")\n    parser.add_argument(\"--jsonl-path\", type=str, required=True, help=\"[Eval Config] Path to the folder of .jsonl files containing recorded actions.\")\n    parser.add_argument(\"--video-path\", type=str, required=True, help=\"[Eval Config] Path to the folder containing .mp4 videos.\")\n    parser.add_argument(\"--infer-demo-num\", type=int, default=0, help=\"[Inference Config] Number of frames to skip before starting evaluation.\")\n    parser.add_argument(\"--n-frames\", type=int, default=32, help=\"[Inference Config] Number of frames to 
process.\")\n    parser.add_argument(\"--output-file\", type=str, default=\"output/action_loss.jsonl\", help=\"[Eval Config] Path to save the action loss.\")\n    args = parser.parse_args()\n    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)\n    evaluate_IDM_quality(args.model, args.weights, args.jsonl_path, args.video_path, args.infer_demo_num, args.n_frames, args.output_file)\n"
  },
  {
    "path": "metrics/IDM/lib/__init__.py",
    "content": ""
  },
  {
    "path": "metrics/IDM/lib/action_head.py",
    "content": "import logging\nfrom typing import Any, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\nfrom gym3.types import DictType, Discrete, Real, TensorType, ValType\n\nLOG0 = -100\n\n\ndef fan_in_linear(module: nn.Module, scale=1.0, bias=True):\n    \"\"\"Fan-in init\"\"\"\n    module.weight.data *= scale / module.weight.norm(dim=1, p=2, keepdim=True)\n\n    if bias:\n        module.bias.data *= 0\n\n\nclass ActionHead(nn.Module):\n    \"\"\"Abstract base class for action heads compatible with forc\"\"\"\n\n    def forward(self, input_data: torch.Tensor) -> Any:\n        \"\"\"\n        Just a forward pass through this head\n        :returns pd_params - parameters describing the probability distribution\n        \"\"\"\n        raise NotImplementedError\n\n    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:\n        \"\"\"Logartithm of probability of sampling `action_sample` from a probability described by `pd_params`\"\"\"\n        raise NotImplementedError\n\n    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:\n        \"\"\"Entropy of this distribution\"\"\"\n        raise NotImplementedError\n\n    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> Any:\n        \"\"\"\n        Draw a sample from probability distribution given by those params\n\n        :param pd_params Parameters of a probability distribution\n        :param deterministic Whether to return a stochastic sample or deterministic mode of a distribution\n        \"\"\"\n        raise NotImplementedError\n\n    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:\n        \"\"\"KL divergence between two distribution described by these two params\"\"\"\n        raise NotImplementedError\n\n\nclass DiagGaussianActionHead(ActionHead):\n    \"\"\"\n    Action head where actions are normally distributed 
uncorrelated variables with specific means and variances.\n\n    Means are calculated directly from the network while standard deviations are a parameter of this module\n    \"\"\"\n\n    LOG2PI = np.log(2.0 * np.pi)\n\n    def __init__(self, input_dim: int, num_dimensions: int):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.num_dimensions = num_dimensions\n\n        self.linear_layer = nn.Linear(input_dim, num_dimensions)\n        self.log_std = nn.Parameter(torch.zeros(num_dimensions), requires_grad=True)\n\n    def reset_parameters(self):\n        init.orthogonal_(self.linear_layer.weight, gain=0.01)\n        init.constant_(self.linear_layer.bias, 0.0)\n\n    def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor:\n        assert not mask, \"Can not use a mask in a gaussian action head\"\n        means = self.linear_layer(input_data)\n        # Unsqueeze many times to get to the same shape\n        logstd = self.log_std[(None,) * (len(means.shape) - 1)]\n\n        mean_view, logstd = torch.broadcast_tensors(means, logstd)\n\n        return torch.stack([mean_view, logstd], dim=-1)\n\n    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:\n        \"\"\"Log-likelihood\"\"\"\n        means = pd_params[..., 0]\n        log_std = pd_params[..., 1]\n\n        std = torch.exp(log_std)\n\n        z_score = (action_sample - means) / std\n\n        return -(0.5 * ((z_score ** 2 + self.LOG2PI).sum(dim=-1)) + log_std.sum(dim=-1))\n\n    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Categorical distribution entropy calculation - sum probs * log(probs).\n        In case of diagonal gaussian distribution - 1/2 log(2 pi e sigma^2)\n        \"\"\"\n        log_std = pd_params[..., 1]\n        return (log_std + 0.5 * (self.LOG2PI + 1)).sum(dim=-1)\n\n    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor:\n        means 
= pd_params[..., 0]\n        log_std = pd_params[..., 1]\n\n        if deterministic:\n            return means\n        else:\n            return torch.randn_like(means) * torch.exp(log_std) + means\n\n    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Diagonal Gaussian distribution KL divergence calculation\n        KL(Q || P) = sum Q_i log (Q_i / P_i)\n\n        Formula is:\n        log(sigma_p) - log(sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2)/(2 * sigma_p^2) - 1/2\n        \"\"\"\n        means_q = params_q[..., 0]\n        log_std_q = params_q[..., 1]\n\n        means_p = params_p[..., 0]\n        log_std_p = params_p[..., 1]\n\n        std_q = torch.exp(log_std_q)\n        std_p = torch.exp(log_std_p)\n\n        kl_div = log_std_p - log_std_q + (std_q ** 2 + (means_q - means_p) ** 2) / (2.0 * std_p ** 2) - 0.5\n\n        return kl_div.sum(dim=-1, keepdim=True)\n\n\nclass CategoricalActionHead(ActionHead):\n    \"\"\"Action head with categorical actions\"\"\"\n\n    def __init__(\n        self, input_dim: int, shape: Tuple[int], num_actions: int, builtin_linear_layer: bool = True, temperature: float = 1.0\n    ):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.num_actions = num_actions\n        self.output_shape = shape + (num_actions,)\n        self.temperature = temperature\n\n        if builtin_linear_layer:\n            self.linear_layer = nn.Linear(input_dim, np.prod(self.output_shape))\n        else:\n            assert (\n                input_dim == num_actions\n            ), f\"If input_dim ({input_dim}) != num_actions ({num_actions}), you need a linear layer to convert them.\"\n            self.linear_layer = None\n\n    def reset_parameters(self):\n        if self.linear_layer is not None:\n            init.orthogonal_(self.linear_layer.weight, gain=0.01)\n            init.constant_(self.linear_layer.bias, 0.0)\n            fan_in_linear(self.linear_layer, 
scale=0.01)\n\n    def forward(self, input_data: torch.Tensor, mask=None) -> Any:\n        if self.linear_layer is not None:\n            flat_out = self.linear_layer(input_data)\n        else:\n            flat_out = input_data\n        shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape)\n        shaped_out /= self.temperature\n        if mask is not None:\n            shaped_out[~mask] = LOG0\n\n        # Convert to float32 to avoid RuntimeError: \"log_softmax_lastdim_kernel_impl\" not implemented for 'Half'\n        return F.log_softmax(shaped_out.float(), dim=-1)\n\n    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:\n        value = actions.long().unsqueeze(-1)\n        value, log_pmf = torch.broadcast_tensors(value, logits)\n        value = value[..., :1]\n        result = log_pmf.gather(-1, value).squeeze(-1)\n        # result is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.\n        for _ in self.output_shape[:-1]:\n            result = result.sum(dim=-1)\n        return result\n\n    def entropy(self, logits: torch.Tensor) -> torch.Tensor:\n        \"\"\"Categorical distribution entropy calculation - sum probs * log(probs)\"\"\"\n        probs = torch.exp(logits)\n        entropy = -torch.sum(probs * logits, dim=-1)\n        # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.\n        for _ in self.output_shape[:-1]:\n            entropy = entropy.sum(dim=-1)\n        return entropy\n\n    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:\n        if deterministic:\n            return torch.argmax(logits, dim=-1)\n        else:\n            # Gumbel-max trick.\n            u = torch.rand_like(logits)\n            # In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0\n            # Then the log(-log(1.0)) will give -inf when it should give +inf\n         
   # This is a silly hack to get around that.\n            # This hack does not skew the probability distribution, because this event can't possibly win the argmax.\n            u[u == 1.0] = 0.999\n\n            return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1)\n\n    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Categorical distribution KL divergence calculation\n        KL(Q || P) = sum Q_i log (Q_i / P_i)\n        When talking about logits this is:\n        sum exp(Q_i) * (Q_i - P_i)\n        \"\"\"\n        kl = (torch.exp(logits_q) * (logits_q - logits_p)).sum(-1, keepdim=True)\n        # kl is per-entry, still of size self.output_shape; we need to reduce over the rest of it.\n        for _ in self.output_shape[:-1]:\n            kl = kl.sum(dim=-2)  # dim=-2 because we use keepdim=True above.\n        return kl\n\n\nclass DictActionHead(nn.ModuleDict):\n    \"\"\"Action head with multiple sub-actions\"\"\"\n\n    def reset_parameters(self):\n        for subhead in self.values():\n            subhead.reset_parameters()\n\n    def forward(self, input_data: torch.Tensor, **kwargs) -> Any:\n        \"\"\"\n        :param kwargs: each kwarg should be a dict with keys corresponding to self.keys()\n                e.g. 
if this ModuleDict has submodules keyed by 'A', 'B', and 'C', we could call:\n                    forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7})\n                Then children will be called with:\n                    A: forward(input_data, foo=True, bar=7)\n                    B: forward(input_data)\n                    C: forward(input_data, foo=False)\n        \"\"\"\n        result = {}\n        for head_name, subhead in self.items():\n            head_kwargs = {\n                kwarg_name: kwarg[head_name]\n                for kwarg_name, kwarg in kwargs.items()\n                if kwarg is not None and head_name in kwarg\n            }\n            result[head_name] = subhead(input_data, **head_kwargs)\n        return result\n\n    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:\n        return sum(subhead.logprob(actions[k], logits[k]) for k, subhead in self.items())\n\n    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:\n        return {k: subhead.sample(logits[k], deterministic) for k, subhead in self.items()}\n\n    def entropy(self, logits: torch.Tensor) -> torch.Tensor:\n        return sum(subhead.entropy(logits[k]) for k, subhead in self.items())\n\n    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:\n        return sum(subhead.kl_divergence(logits_q[k], logits_p[k]) for k, subhead in self.items())\n\n\ndef make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):\n    \"\"\"Helper function to create an action head corresponding to the environment action space\"\"\"\n    if isinstance(ac_space, TensorType):\n        if isinstance(ac_space.eltype, Discrete):\n            return CategoricalActionHead(pi_out_size, ac_space.shape, ac_space.eltype.n, temperature=temperature)\n        elif isinstance(ac_space.eltype, Real):\n            if temperature != 1.0:\n                logging.warning(\"Non-1 temperature not 
implemented for DiagGaussianActionHead.\")\n            assert len(ac_space.shape) == 1, \"Nontrivial shapes not yet implemented.\"\n            return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])\n    elif isinstance(ac_space, DictType):\n        return DictActionHead({k: make_action_head(v, pi_out_size, temperature) for k, v in ac_space.items()})\n    raise NotImplementedError(f\"Action space of type {type(ac_space)} is not supported\")\n"
  },
  {
    "path": "metrics/IDM/lib/action_mapping.py",
    "content": "import abc\nimport itertools\nfrom collections import OrderedDict\nfrom typing import Dict, List\n\nimport numpy as np\nfrom gym3.types import DictType, Discrete, TensorType\n\nfrom lib.actions import Buttons\n\n\nclass ActionMapping(abc.ABC):\n    \"\"\"Class that maps between the standard MC factored action space and a new one you define!\n\n    :param n_camera_bins: Need to specify this to define the original ac space for stats code\n    \"\"\"\n\n    # This is the default buttons groups, it can be changed for your action space\n    BUTTONS_GROUPS = OrderedDict(\n        hotbar=[\"none\"] + [f\"hotbar.{i}\" for i in range(1, 10)],\n        fore_back=[\"none\", \"forward\", \"back\"],\n        left_right=[\"none\", \"left\", \"right\"],\n        sprint_sneak=[\"none\", \"sprint\", \"sneak\"],\n        use=[\"none\", \"use\"],\n        drop=[\"none\", \"drop\"],\n        attack=[\"none\", \"attack\"],\n        jump=[\"none\", \"jump\"],\n    )\n\n    def __init__(self, n_camera_bins: int = 11):\n        assert n_camera_bins % 2 == 1, \"n_camera_bins should be odd\"\n        self.n_camera_bins = n_camera_bins\n        self.camera_null_bin = n_camera_bins // 2\n        self.stats_ac_space = DictType(\n            **{\n                \"buttons\": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)),\n                \"camera\": TensorType(shape=(2,), eltype=Discrete(n_camera_bins)),\n            }\n        )\n\n    @abc.abstractmethod\n    def from_factored(self, ac: Dict) -> Dict:\n        \"\"\"Converts a factored action (ac) to the new space\n\n        :param ac: Dictionary of actions that must have a batch dimension\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def to_factored(self, ac: Dict) -> Dict:\n        \"\"\"Converts an action in the new space (ac) to the factored action space.\n\n        :param ac: Dictionary of actions that must have a batch dimension\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def 
get_action_space_update(self):\n        \"\"\"Return a magym (gym3) action space. This will be used to update the env action space.\"\"\"\n        pass\n\n    @abc.abstractmethod\n    def get_zero_action(self):\n        \"\"\"Return the zero or null action for this action space\"\"\"\n        pass\n\n    def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_group: List[str]) -> List[str]:\n        \"\"\"For a mutually exclusive group of buttons in button_group, find which option\n        in the group was chosen. Assumes that each button group has the option of 'none'\n        meaning that no button in the group was pressed.\n\n        :param ac_buttons: button actions from the factored action space. Should have dims [B, len(Buttons.ALL)]\n        :param button_group: List of buttons in a mutually exclusive group. Each item in the\n            list should appear in Buttons.ALL except for the special case 'none' which means\n            no button in the group was pressed. e.g. ['none', 'forward', 'back']. 
For now\n            'none' must be the first element of button_group\n\n        Returns a list of length B, where each element is an item from button_group.\n        \"\"\"\n        assert ac_buttons.shape[1] == len(\n            Buttons.ALL\n        ), f\"There should be {len(Buttons.ALL)} buttons in the factored buttons space\"\n        assert button_group[0] == \"none\", \"This function only works if 'none' is in button_group\"\n        # Actions in ac_buttons with order according to button_group\n        group_indices = [Buttons.ALL.index(b) for b in button_group if b != \"none\"]\n        ac_choices = ac_buttons[:, group_indices]\n\n        # Special cases for forward/back, left/right where mutual press means do neither\n        if \"forward\" in button_group and \"back\" in button_group:\n            ac_choices[np.all(ac_choices, axis=-1)] = 0\n        if \"left\" in button_group and \"right\" in button_group:\n            ac_choices[np.all(ac_choices, axis=-1)] = 0\n        ac_non_zero = np.where(ac_choices)\n        ac_choice = [\"none\" for _ in range(ac_buttons.shape[0])]\n        # Iterate over the non-zero indices so that if two buttons in a group were pressed at the same time\n        # we give priority to the button later in the group. E.g. if hotbar.1 and hotbar.2 are pressed during the same\n        # timestep, hotbar.2 is marked as pressed\n        for index, action in zip(ac_non_zero[0], ac_non_zero[1]):\n            ac_choice[index] = button_group[action + 1]  # the zero'th index will mean no button pressed\n        return ac_choice\n\nclass IDMActionMapping(ActionMapping):\n    \"\"\"For IDM, but essentially this is just an identity mapping\"\"\"\n    def from_factored(self, ac: Dict) -> Dict:\n        return ac\n\n    def to_factored(self, ac: Dict) -> Dict:\n        return ac\n\n    def get_action_space_update(self):\n        \"\"\"Return a magym (gym3) action space. 
This will be used to update the env action space.\"\"\"\n        return {\n            \"buttons\": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)),\n            \"camera\": TensorType(shape=(2,), eltype=Discrete(self.n_camera_bins)),\n        }\n\n    def get_zero_action(self):\n        raise NotImplementedError()\n\nclass CameraHierarchicalMapping(ActionMapping):\n    \"\"\"Buttons are joint as in ButtonsJointMapping, but now a camera on/off meta action is added into this joint space.\n    When this meta action is triggered, the separate camera head chooses a camera action which is also now a joint space.\n\n    :param n_camera_bins: number of camera bins in the factored space\n    \"\"\"\n\n    # Add camera meta action to BUTTONS_GROUPS\n    BUTTONS_GROUPS = ActionMapping.BUTTONS_GROUPS.copy()\n    BUTTONS_GROUPS[\"camera\"] = [\"none\", \"camera\"]\n    BUTTONS_COMBINATIONS = list(itertools.product(*BUTTONS_GROUPS.values())) + [\"inventory\"]\n    BUTTONS_COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(BUTTONS_COMBINATIONS)}\n    BUTTONS_IDX_TO_COMBINATION = {i: comb for i, comb in enumerate(BUTTONS_COMBINATIONS)}\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.camera_groups = OrderedDict(\n            camera_x=[f\"camera_x{i}\" for i in range(self.n_camera_bins)],\n            camera_y=[f\"camera_y{i}\" for i in range(self.n_camera_bins)],\n        )\n        self.camera_combinations = list(itertools.product(*self.camera_groups.values()))\n        self.camera_combination_to_idx = {comb: i for i, comb in enumerate(self.camera_combinations)}\n        self.camera_idx_to_combination = {i: comb for i, comb in enumerate(self.camera_combinations)}\n        self.camera_null_idx = self.camera_combination_to_idx[\n            (f\"camera_x{self.camera_null_bin}\", f\"camera_y{self.camera_null_bin}\")\n        ]\n        self._null_action = {\n            \"buttons\": 
self.BUTTONS_COMBINATION_TO_IDX[tuple(\"none\" for _ in range(len(self.BUTTONS_GROUPS)))]\n        }\n        self._precompute_to_factored()\n\n    def _precompute_to_factored(self):\n        \"\"\"Precompute the joint action -> factored action matrix.\"\"\"\n        button_dim = self.stats_ac_space[\"buttons\"].size\n        self.BUTTON_IDX_TO_FACTORED = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION), button_dim), dtype=int)\n        self.BUTTON_IDX_TO_CAMERA_META_OFF = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION)), dtype=bool)\n        self.CAMERA_IDX_TO_FACTORED = np.zeros((len(self.camera_idx_to_combination), 2), dtype=int)\n\n        # Pre compute Buttons\n        for jnt_ac, button_comb in self.BUTTONS_IDX_TO_COMBINATION.items():\n            new_button_ac = np.zeros(len(Buttons.ALL), dtype=\"i\")\n            if button_comb == \"inventory\":\n                new_button_ac[Buttons.ALL.index(\"inventory\")] = 1\n            else:\n                for group_choice in button_comb[:-1]:  # Last one is camera\n                    if group_choice != \"none\":\n                        new_button_ac[Buttons.ALL.index(group_choice)] = 1\n\n                if button_comb[-1] != \"camera\":  # This means camera meta action is off\n                    self.BUTTON_IDX_TO_CAMERA_META_OFF[jnt_ac] = True\n            self.BUTTON_IDX_TO_FACTORED[jnt_ac] = new_button_ac\n\n        # Pre compute camera\n        for jnt_ac, camera_comb in self.camera_idx_to_combination.items():\n            new_camera_ac = np.ones((2), dtype=\"i\") * self.camera_null_bin\n            new_camera_ac[0] = self.camera_groups[\"camera_x\"].index(camera_comb[0])\n            new_camera_ac[1] = self.camera_groups[\"camera_y\"].index(camera_comb[1])\n            self.CAMERA_IDX_TO_FACTORED[jnt_ac] = new_camera_ac\n\n    def from_factored(self, ac: Dict) -> Dict:\n        \"\"\"Converts a factored action (ac) to the new space. 
Assumes ac has a batch dim\"\"\"\n        assert ac[\"camera\"].ndim == 2, f\"bad camera label, {ac['camera']}\"\n        assert ac[\"buttons\"].ndim == 2, f\"bad buttons label, {ac['buttons']}\"\n        # Get button choices for everything but camera\n        choices_by_group = OrderedDict(\n            (k, self.factored_buttons_to_groups(ac[\"buttons\"], v)) for k, v in self.BUTTONS_GROUPS.items() if k != \"camera\"\n        )\n        # Set camera \"on off\" action based on whether non-null camera action was given\n        camera_is_null = np.all(ac[\"camera\"] == self.camera_null_bin, axis=1)\n        choices_by_group[\"camera\"] = [\"none\" if is_null else \"camera\" for is_null in camera_is_null]\n\n        new_button_ac = []\n        new_camera_ac = []\n        for i in range(ac[\"buttons\"].shape[0]):\n            # Buttons\n            key = tuple([v[i] for v in choices_by_group.values()])\n            if ac[\"buttons\"][i, Buttons.ALL.index(\"inventory\")] == 1:\n                key = \"inventory\"\n            new_button_ac.append(self.BUTTONS_COMBINATION_TO_IDX[key])\n\n            # Camera -- inventory is also exclusive with camera\n            if key == \"inventory\":\n                key = (\n                    f\"camera_x{self.camera_null_bin}\",\n                    f\"camera_y{self.camera_null_bin}\",\n                )\n            else:\n                key = (f\"camera_x{ac['camera'][i][0]}\", f\"camera_y{ac['camera'][i][1]}\")\n            new_camera_ac.append(self.camera_combination_to_idx[key])\n\n        return dict(\n            buttons=np.array(new_button_ac)[:, None],\n            camera=np.array(new_camera_ac)[:, None],\n        )\n\n    def to_factored(self, ac: Dict) -> Dict:\n        \"\"\"Converts an action in the new space (ac) to the factored action space. 
Assumes ac has a batch dim\"\"\"\n        assert ac[\"camera\"].shape[-1] == 1\n        assert ac[\"buttons\"].shape[-1] == 1\n\n        new_button_ac = self.BUTTON_IDX_TO_FACTORED[np.squeeze(ac[\"buttons\"], -1)]\n        camera_off = self.BUTTON_IDX_TO_CAMERA_META_OFF[np.squeeze(ac[\"buttons\"], -1)]\n        new_camera_ac = self.CAMERA_IDX_TO_FACTORED[np.squeeze(ac[\"camera\"], -1)]\n        new_camera_ac[camera_off] = self.camera_null_bin\n\n        return dict(buttons=new_button_ac, camera=new_camera_ac)\n\n    def get_action_space_update(self):\n        return {\n            \"camera\": TensorType(shape=(1,), eltype=Discrete(len(self.camera_combinations))),\n            \"buttons\": TensorType(shape=(1,), eltype=Discrete(len(self.BUTTONS_COMBINATIONS))),\n        }\n\n    def get_zero_action(self):\n        return self._null_action\n\n"
  },
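The joint-index bookkeeping above (`BUTTONS_COMBINATION_TO_IDX` and its inverse) can be illustrated with a minimal standalone sketch. The button groups below are made up for illustration; the real class derives its groups, camera bins, and the `inventory`/`camera` special cases elsewhere:

```python
import itertools

# Hypothetical button groups; each group contributes one choice ("none" = inactive).
BUTTONS_GROUPS = {
    "fore_back": ["none", "forward", "back"],
    "left_right": ["none", "left", "right"],
    "camera": ["none", "camera"],
}

# Enumerate every combination of group choices and assign each a joint index.
combinations = list(itertools.product(*BUTTONS_GROUPS.values()))
COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(combinations)}
IDX_TO_COMBINATION = {i: comb for comb, i in COMBINATION_TO_IDX.items()}

def roundtrip(comb):
    """Combination -> joint index -> combination must be the identity."""
    return IDX_TO_COMBINATION[COMBINATION_TO_IDX[comb]]
```

The all-`"none"` tuple plays the role of the null action, mirroring how `_null_action` indexes the all-`"none"` combination in the file above.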
  {
    "path": "metrics/IDM/lib/actions.py",
    "content": "import attr\n# import minerl.herobraine.hero.mc as mc\nimport numpy as np\n\nfrom lib.minecraft_util import store_args\n\n\nclass Buttons:\n    ATTACK = \"attack\"\n    BACK = \"back\"\n    FORWARD = \"forward\"\n    JUMP = \"jump\"\n    LEFT = \"left\"\n    RIGHT = \"right\"\n    SNEAK = \"sneak\"\n    SPRINT = \"sprint\"\n    USE = \"use\"\n    DROP = \"drop\"\n    INVENTORY = \"inventory\"\n\n    ALL = [\n        ATTACK,\n        BACK,\n        FORWARD,\n        JUMP,\n        LEFT,\n        RIGHT,\n        SNEAK,\n        SPRINT,\n        USE,\n        DROP,\n        INVENTORY,\n    ] + [f\"hotbar.{i}\" for i in range(1, 10)]\n\n\nclass SyntheticButtons:\n    # Composite / scripted actions\n    CHANNEL_ATTACK = \"channel-attack\"\n\n    ALL = [CHANNEL_ATTACK]\n\n\nclass QuantizationScheme:\n    LINEAR = \"linear\"\n    MU_LAW = \"mu_law\"\n\n\n@attr.s(auto_attribs=True)\nclass CameraQuantizer:\n    \"\"\"\n    A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components.\n\n    Parameters:\n    - camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize.\n    - camera_maxval: The maximum value of the camera action.\n    - quantization_scheme: The quantization scheme to use. Currently, two quantization schemes are supported:\n    - Linear quantization (default): Camera actions are split uniformly into discrete bins\n    - Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm)\n    followed by the same quantization scheme used by the linear scheme.\n    - mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of\n    mu will result in a sharper transition near zero. 
Below are some reference values listed\n    for choosing mu given a constant maxval and a desired max_precision value.\n    maxval = 10 | max_precision = 0.5  | μ ≈ 2.93826\n    maxval = 10 | max_precision = 0.4  | μ ≈ 4.80939\n    maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887\n    maxval = 20 | max_precision = 0.5  | μ ≈ 2.7\n    maxval = 20 | max_precision = 0.4  | μ ≈ 4.39768\n    maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194\n    maxval = 40 | max_precision = 0.5  | μ ≈ 2.60780\n    maxval = 40 | max_precision = 0.4  | μ ≈ 4.21554\n    maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152\n    \"\"\"\n\n    camera_maxval: int\n    camera_binsize: int\n    quantization_scheme: str = attr.ib(\n        default=QuantizationScheme.LINEAR,\n        validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),\n    )\n    mu: float = attr.ib(default=5)\n\n    def discretize(self, xy):\n        xy = np.clip(xy, -self.camera_maxval, self.camera_maxval)\n\n        if self.quantization_scheme == QuantizationScheme.MU_LAW:\n            xy = xy / self.camera_maxval\n            v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu))\n            v_encode *= self.camera_maxval\n            xy = v_encode\n\n        # Quantize using linear scheme\n        return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64)\n\n    def undiscretize(self, xy):\n        xy = xy * self.camera_binsize - self.camera_maxval\n\n        if self.quantization_scheme == QuantizationScheme.MU_LAW:\n            xy = xy / self.camera_maxval\n            v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0)\n            v_decode *= self.camera_maxval\n            xy = v_decode\n        return xy\n\n\nclass ActionTransformer:\n    \"\"\"Transforms actions between internal array and minerl env format.\"\"\"\n\n    @store_args\n    def __init__(\n        self,\n        camera_maxval=10,\n        
camera_binsize=2,\n        camera_quantization_scheme=\"linear\",\n        camera_mu=5,\n    ):\n        self.quantizer = CameraQuantizer(\n            camera_maxval=camera_maxval,\n            camera_binsize=camera_binsize,\n            quantization_scheme=camera_quantization_scheme,\n            mu=camera_mu,\n        )\n\n    def camera_zero_bin(self):\n        return self.camera_maxval // self.camera_binsize\n\n    def discretize_camera(self, xy):\n        return self.quantizer.discretize(xy)\n\n    def undiscretize_camera(self, pq):\n        return self.quantizer.undiscretize(pq)\n\n    def item_embed_id_to_name(self, item_id):\n        return mc.MINERL_ITEM_MAP[item_id]\n\n    def dict_to_numpy(self, acs):\n        \"\"\"\n        Env format to policy output format.\n        \"\"\"\n        act = {\n            \"buttons\": np.stack([acs.get(k, 0) for k in Buttons.ALL], axis=-1),\n            \"camera\": self.discretize_camera(acs[\"camera\"]),\n        }\n        if not self.human_spaces:\n            act.update(\n                {\n                    \"synthetic_buttons\": np.stack([acs[k] for k in SyntheticButtons.ALL], axis=-1),\n                    \"place\": self.item_embed_name_to_id(acs[\"place\"]),\n                    \"equip\": self.item_embed_name_to_id(acs[\"equip\"]),\n                    \"craft\": self.item_embed_name_to_id(acs[\"craft\"]),\n                }\n            )\n        return act\n\n    def numpy_to_dict(self, acs):\n        \"\"\"\n        Numpy policy output to env-compatible format.\n        \"\"\"\n        assert acs[\"buttons\"].shape[-1] == len(\n            Buttons.ALL\n        ), f\"Mismatched actions: {acs}; expected {len(Buttons.ALL)}:\\n(  {Buttons.ALL})\"\n        out = {name: acs[\"buttons\"][..., i] for (i, name) in enumerate(Buttons.ALL)}\n\n        out[\"camera\"] = self.undiscretize_camera(acs[\"camera\"])\n\n        return out\n\n    def policy2env(self, acs):\n        acs = self.numpy_to_dict(acs)\n        
return acs\n\n    def env2policy(self, acs):\n        nbatch = acs[\"camera\"].shape[0]\n        dummy = np.zeros((nbatch,))\n        out = {\n            \"camera\": self.discretize_camera(acs[\"camera\"]),\n            \"buttons\": np.stack([acs.get(k, dummy) for k in Buttons.ALL], axis=-1),\n        }\n        return out\n"
  },
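The `CameraQuantizer` logic above (clip, optional mu-law compression, then linear binning) can be sketched as two standalone functions. This is a simplified restatement of the same math, not the class itself; defaults mirror `ActionTransformer`'s `camera_maxval=10`, `camera_binsize=2`:

```python
import numpy as np

def discretize(xy, maxval=10, binsize=2, mu=None):
    """Clip, optionally mu-law compress, then linearly bin a camera angle."""
    xy = np.clip(np.asarray(xy, dtype=float), -maxval, maxval)
    if mu is not None:
        v = xy / maxval
        xy = np.sign(v) * (np.log1p(mu * np.abs(v)) / np.log1p(mu)) * maxval
    return np.round((xy + maxval) / binsize).astype(np.int64)

def undiscretize(bins, maxval=10, binsize=2, mu=None):
    """Inverse of discretize, up to quantization error."""
    xy = np.asarray(bins, dtype=float) * binsize - maxval
    if mu is not None:
        v = xy / maxval
        xy = np.sign(v) * (1.0 / mu) * ((1.0 + mu) ** np.abs(v) - 1.0) * maxval
    return xy
```

Note the zero bin sits at `maxval // binsize`, which is what `camera_zero_bin()` returns; mu-law trades uniform bins for finer resolution near zero.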
  {
    "path": "metrics/IDM/lib/impala_cnn.py",
    "content": "import math\nfrom copy import deepcopy\nfrom typing import Dict, List, Optional\n\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom lib import misc\nfrom lib import torch_util as tu\nfrom lib.util import FanInInitReLULayer\n\n\nclass CnnBasicBlock(nn.Module):\n    \"\"\"\n    Residual basic block, as in ImpalaCNN. Preserves channel number and shape\n    :param inchan: number of input channels\n    :param init_scale: weight init scale multiplier\n    \"\"\"\n\n    def __init__(\n        self,\n        inchan: int,\n        init_scale: float = 1,\n        log_scope=\"\",\n        init_norm_kwargs: Dict = {},\n        **kwargs,\n    ):\n        super().__init__()\n        self.inchan = inchan\n        s = math.sqrt(init_scale)\n        self.conv0 = FanInInitReLULayer(\n            self.inchan,\n            self.inchan,\n            kernel_size=3,\n            padding=1,\n            init_scale=s,\n            log_scope=f\"{log_scope}/conv0\",\n            **init_norm_kwargs,\n        )\n        self.conv1 = FanInInitReLULayer(\n            self.inchan,\n            self.inchan,\n            kernel_size=3,\n            padding=1,\n            init_scale=s,\n            log_scope=f\"{log_scope}/conv1\",\n            **init_norm_kwargs,\n        )\n\n    def forward(self, x):\n        x = x + self.conv1(self.conv0(x))\n        return x\n\n\nclass CnnDownStack(nn.Module):\n    \"\"\"\n    Downsampling stack from Impala CNN.\n    :param inchan: number of input channels\n    :param nblock: number of residual blocks after downsampling\n    :param outchan: number of output channels\n    :param init_scale: weight init scale multiplier\n    :param pool: if true, downsample with max pool\n    :param post_pool_groups: if not None, normalize with group norm with this many groups\n    :param kwargs: remaining kwargs are passed into the blocks and layers\n    \"\"\"\n\n    name = \"Impala_CnnDownStack\"\n\n    def __init__(\n        self,\n        
inchan: int,\n        nblock: int,\n        outchan: int,\n        init_scale: float = 1,\n        pool: bool = True,\n        post_pool_groups: Optional[int] = None,\n        log_scope: str = \"\",\n        init_norm_kwargs: Dict = {},\n        first_conv_norm=False,\n        **kwargs,\n    ):\n        super().__init__()\n        self.inchan = inchan\n        self.outchan = outchan\n        self.pool = pool\n        first_conv_init_kwargs = deepcopy(init_norm_kwargs)\n        if not first_conv_norm:\n            first_conv_init_kwargs[\"group_norm_groups\"] = None\n            first_conv_init_kwargs[\"batch_norm\"] = False\n        self.firstconv = FanInInitReLULayer(\n            inchan,\n            outchan,\n            kernel_size=3,\n            padding=1,\n            log_scope=f\"{log_scope}/firstconv\",\n            **first_conv_init_kwargs,\n        )\n        self.post_pool_groups = post_pool_groups\n        if post_pool_groups is not None:\n            self.n = nn.GroupNorm(post_pool_groups, outchan)\n        self.blocks = nn.ModuleList(\n            [\n                CnnBasicBlock(\n                    outchan,\n                    init_scale=init_scale / math.sqrt(nblock),\n                    log_scope=f\"{log_scope}/block{i}\",\n                    init_norm_kwargs=init_norm_kwargs,\n                    **kwargs,\n                )\n                for i in range(nblock)\n            ]\n        )\n\n    def forward(self, x):\n        x = self.firstconv(x)\n        if self.pool:\n            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)\n            if self.post_pool_groups is not None:\n                x = self.n(x)\n        x = tu.sequential(self.blocks, x, diag_name=self.name)\n        return x\n\n    def output_shape(self, inshape):\n        c, h, w = inshape\n        assert c == self.inchan\n        if self.pool:\n            return (self.outchan, (h + 1) // 2, (w + 1) // 2)\n        else:\n            return (self.outchan, h, 
w)\n\n\nclass ImpalaCNN(nn.Module):\n    \"\"\"\n    :param inshape: input image shape (height, width, channels)\n    :param chans: list with one entry per residual downsample stack; each element is the number of\n        filters per convolution in that stack\n    :param outsize: output hidden size\n    :param nblock: number of residual blocks per stack. Each block has 2 convs and a residual\n    :param init_norm_kwargs: arguments to be passed to convolutional layers. Options can be found\n        in ypt.model.util:FanInInitReLULayer\n    :param dense_init_norm_kwargs: arguments to be passed to the final dense layer. Options can be found\n        in ypt.model.util:FanInInitReLULayer\n    :param kwargs: remaining kwargs are passed into the CnnDownStacks\n    \"\"\"\n\n    name = \"ImpalaCNN\"\n\n    def __init__(\n        self,\n        inshape: List[int],\n        chans: List[int],\n        outsize: int,\n        nblock: int,\n        init_norm_kwargs: Dict = {},\n        dense_init_norm_kwargs: Dict = {},\n        first_conv_norm=False,\n        **kwargs,\n    ):\n        super().__init__()\n        h, w, c = inshape\n        curshape = (c, h, w)\n        self.stacks = nn.ModuleList()\n        for i, outchan in enumerate(chans):\n            stack = CnnDownStack(\n                curshape[0],\n                nblock=nblock,\n                outchan=outchan,\n                init_scale=math.sqrt(len(chans)),\n                log_scope=f\"downstack{i}\",\n                init_norm_kwargs=init_norm_kwargs,\n                first_conv_norm=first_conv_norm if i == 0 else True,\n                **kwargs,\n            )\n            self.stacks.append(stack)\n            curshape = stack.output_shape(curshape)\n\n        self.dense = FanInInitReLULayer(\n            misc.intprod(curshape),\n            outsize,\n            layer_type=\"linear\",\n            log_scope=\"impala_final_dense\",\n            init_scale=1.4,\n            **dense_init_norm_kwargs,\n        )\n        
self.outsize = outsize\n\n    def forward(self, x):\n        b, t = x.shape[:-3]\n        x = x.reshape(b * t, *x.shape[-3:])\n        x = misc.transpose(x, \"bhwc\", \"bchw\")\n        x = tu.sequential(self.stacks, x, diag_name=self.name)\n        x = x.reshape(b, t, *x.shape[1:])\n        x = tu.flatten_image(x)\n        x = self.dense(x)\n        return x\n"
  },
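The shape arithmetic in `CnnDownStack.output_shape` (a 3x3, stride-2 max pool with padding 1 maps `h -> (h + 1) // 2`) determines the input size of the final dense layer. A sketch of that bookkeeping alone, with no torch dependency:

```python
def down_stack_shape(inshape, outchan, pool=True):
    """Mirror CnnDownStack.output_shape: pooling halves h and w (rounding up)."""
    c, h, w = inshape
    if pool:
        return (outchan, (h + 1) // 2, (w + 1) // 2)
    return (outchan, h, w)

def impala_shapes(inshape_hwc, chans):
    """Trace the feature-map shape through each downsample stack."""
    h, w, c = inshape_hwc
    cur = (c, h, w)
    shapes = []
    for outchan in chans:
        cur = down_stack_shape(cur, outchan)
        shapes.append(cur)
    return shapes
```

The product of the last shape is what `misc.intprod(curshape)` feeds into the final dense layer in `ImpalaCNN.__init__`.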
  {
    "path": "metrics/IDM/lib/masked_attention.py",
    "content": "import functools\n\nimport torch as th\nfrom torch import nn\n\nimport lib.xf as xf\nfrom lib.minecraft_util import store_args\nfrom lib.tree_util import tree_map\n\n\n@functools.lru_cache()\ndef get_band_diagonal_mask(t: int, T: int, maxlen: int, batchsize: int, device: th.device) -> th.Tensor:\n    \"\"\"Returns a band diagonal mask which is causal (upper triangle is masked)\n    and such that any frame can only view up to maxlen total past frames\n    including the current frame.\n\n    Example Masks: Here 0 means that frame is masked and we mask it by adding a huge number to the attention logits (see orc.xf)\n        t = 3, T = 3, maxlen = 3\n          T\n        t 1 0 0 |  mask out T > t\n          1 1 0 |\n          1 1 1 |\n        t = 3, T = 6, maxlen = 3\n        t 0 1 1 1 0 0 |  mask out T > t\n          0 0 1 1 1 0 |\n          0 0 0 1 1 1 |\n\n    Args:\n        t: number of rows (presumably number of frames receiving gradient)\n        T: number of cols (presumably t + past context that isn't being gradient updated)\n        maxlen: maximum number of frames (including current frame) any frame can attend to\n        batchsize: number of masks to return\n        device: torch device to place mask on\n\n    Returns:\n        Boolean mask of shape (batchsize, t, T)\n    \"\"\"\n    m = th.ones(t, T, dtype=bool)\n    m.tril_(T - t)  # Mask out upper triangle\n    if maxlen is not None and maxlen < T:  # Mask out lower triangle\n        m.triu_(T - t - maxlen + 1)\n    m_btT = m[None].repeat_interleave(batchsize, dim=0)\n    m_btT = m_btT.to(device=device)\n    return m_btT\n\n\ndef get_mask(first_b11: th.Tensor, state_mask: th.Tensor, t: int, T: int, maxlen: int, heads: int, device) -> th.Tensor:\n    \"\"\"Returns a band diagonal mask that respects masking past states (columns 0:T-t inclusive)\n        if first_b11 is True. 
See get_band_diagonal_mask for how the base mask is computed.\n        This function takes that mask and first zeros out any past context if first_b11 is True.\n\n        Say our context is in chunks of length t (so here T = 4t). We see that in the second batch we received first=True\n        context     t t t t\n        first       F T F F\n        Now, given this, the mask should mask out anything prior to T < t; however since we don't have access to the past first_b11's\n        we need to keep a state of the mask at those past timesteps. This is what state_mask is.\n\n        In particular, state_mask is a [b, t, T - t] mask matrix that contains the mask for the past T - t frames.\n\n    Args: (See get_band_diagonal_mask for remaining args)\n        first_b11: boolean tensor with shape [batchsize, 1, 1] indicating if the first timestep for each batch element had first=True\n        state_mask: mask tensor of shape [b, t, T - t]\n        t: number of mask rows (presumably number of frames for which we take gradient)\n        T: number of mask columns (t + the number of past frames we keep in context)\n        maxlen: actual context length\n        heads: number of attention heads\n        device: torch device\n\n    Returns:\n        m_btT: Boolean mask of shape (batchsize * heads, t, T)\n        state_mask: updated state_mask\n    \"\"\"\n    b = first_b11.shape[0]\n\n    if state_mask is None:\n        state_mask = th.zeros((b, 1, T - t), dtype=bool, device=device)\n\n    m_btT = get_band_diagonal_mask(t, T, maxlen, b, device).clone()  # Should be shape B, t, T\n    not_first = ~first_b11.to(device=device)\n    m_btT[:, :, :-t] &= not_first  # Zero out anything in the past if first is true\n    m_btT[:, :, :-t] &= state_mask\n    m_bhtT = m_btT[:, None].repeat_interleave(heads, dim=1)\n    m_btT = m_bhtT.reshape((b * heads), t, T)\n\n    # Update state_mask such that it reflects the most recent first\n    state_mask = th.cat(\n        [\n            
state_mask[:, :, t:] & not_first,\n            th.ones((b, 1, min(t, T - t)), dtype=bool, device=device),\n        ],\n        dim=-1,\n    )\n\n    return m_btT, state_mask\n\n\nclass MaskedAttention(nn.Module):\n    \"\"\"\n    Transformer self-attention layer that removes frames from previous episodes from the hidden state under certain constraints.\n\n    The constraints are:\n    - The \"first\" flag can only be true for the first timestep of each batch. An assert will fire if other timesteps have first = True.\n\n    input_size: The dimension of the input (which also happens to be the size of the output)\n    memory_size: The number of frames to keep in the inner state. Note that when attending, we will be able to attend\n                 to both the frames in the inner state (which presumably won't have gradients anymore) and the frames\n                 in the batch. \"mask\" for some additional considerations on this.\n    heads: The number of attention heads to use. Note that we will split the input into this number of heads, so\n           input_size needs to be divisible by heads.\n    timesteps: number of timesteps with which we'll be taking gradient\n    mask: Can be \"none\" or \"clipped_causal\". \"clipped_causal\" is a normal causal mask but solves the following minor problem:\n        if you have a state of length 128 and a batch of 128 frames, then the first frame of your batch will be able to\n        attend to 128 previous frames, but the last one will be able to attend to 255 previous frames. In this example,\n        \"clipped_causal\" will make it so that the last frame can only attend to 128 previous frames, so that there is no\n        bias coming from the position in the batch. 
None simply allows you to attend to any frame in the state + batch,\n        which means you can also attend to future frames.\n    \"\"\"\n\n    @store_args\n    def __init__(\n        self,\n        input_size,\n        memory_size: int,\n        heads: int,\n        timesteps: int,\n        mask: str = \"clipped_causal\",\n        init_scale=1,\n        norm=\"none\",\n        log_scope=\"sa\",\n        use_muP_factor=False,\n    ):\n        super().__init__()\n\n        assert mask in {\"none\", \"clipped_causal\"}\n        assert memory_size >= 0\n\n        self.maxlen = memory_size - timesteps\n        if mask == \"none\":\n            mask = None\n\n        self.orc_attn = xf.All2All(heads, self.maxlen, mask=mask is not None)\n        self.orc_block = xf.SelfAttentionLayer(\n            input_size,\n            self.orc_attn,\n            scale=init_scale,\n            relattn=True,\n            cache_keep_len=self.maxlen,\n            norm=norm,\n            log_scope=log_scope,\n            use_muP_factor=use_muP_factor,\n        )\n\n    def initial_state(self, batchsize: int, device=None):\n        \"\"\"Return the initial state mask (None) and the initial state of the transformer (zeroed out keys and queries)\"\"\"\n        state = self.orc_block.initial_state(batchsize, initial_T=self.maxlen)\n        state_mask = None\n        if device is not None:\n            state = tree_map(lambda x: x.to(device), state)\n        return state_mask, state\n\n    def forward(self, input_bte, first_bt, state):\n        \"\"\"Forward propagation of a single layer\"\"\"\n        state_mask, xf_state = state\n        t = first_bt.shape[1]\n        if self.mask == \"clipped_causal\":\n            new_mask, state_mask = get_mask(\n                first_b11=first_bt[:, [[0]]],\n                state_mask=state_mask,\n                t=t,\n                T=t + self.maxlen,\n                maxlen=self.maxlen,\n                heads=self.heads,\n                
device=input_bte.device,\n            )\n            self.orc_block.attn.mask = new_mask\n        output, xf_state = self.orc_block(input_bte, xf_state)\n\n        return output, (state_mask, xf_state)\n\n    def get_log_keys(self):\n        # These are logged in xf.SelfAttentionLayer\n        return [f\"activation_{stat}/{self.log_scope}/{k}\" for k in [\"K\", \"Q\", \"V\", \"A\", \"Aproj\"] for stat in [\"mean\", \"std\"]]\n"
  },
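The band diagonal mask in `get_band_diagonal_mask` above can be reproduced with numpy, swapping torch's in-place `tril_`/`triu_` for `np.tril`/`np.triu`. The sketch below reproduces the docstring examples exactly:

```python
import numpy as np

def band_diagonal_mask(t, T, maxlen):
    """Causal band mask: row i may attend to columns j with
    i + (T - t) - maxlen + 1 <= j <= i + (T - t)."""
    m = np.ones((t, T), dtype=bool)
    m = np.tril(m, k=T - t)  # causal: mask out the future (upper triangle)
    if maxlen is not None and maxlen < T:
        m = np.triu(m, k=T - t - maxlen + 1)  # clip to maxlen past frames
    return m
```

Row `i` of the `t = 3, T = 6, maxlen = 3` case allows exactly the band of 3 columns ending at `i + 3`, matching the second example mask in the docstring.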
  {
    "path": "metrics/IDM/lib/minecraft_util.py",
    "content": "import functools\nimport inspect\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\n\nfrom lib.action_head import (CategoricalActionHead, DiagGaussianActionHead,\n                             DictActionHead)\n\n\ndef store_args(method):\n    \"\"\"Stores provided method args as instance attributes.\"\"\"\n    argspec = inspect.getfullargspec(method)\n    defaults = {}\n    if argspec.defaults is not None:\n        defaults = dict(zip(argspec.args[-len(argspec.defaults) :], argspec.defaults))\n    if argspec.kwonlydefaults is not None:\n        defaults.update(argspec.kwonlydefaults)\n    arg_names = argspec.args[1:]\n\n    @functools.wraps(method)\n    def wrapper(*positional_args, **keyword_args):\n        self = positional_args[0]\n        # Get default arg values\n        args = defaults.copy()\n        # Add provided arg values\n        for name, value in zip(arg_names, positional_args[1:]):\n            args[name] = value\n        args.update(keyword_args)\n        self.__dict__.update(args)\n        return method(*positional_args, **keyword_args)\n\n    return wrapper\n\n\ndef get_norm_entropy_from_cat_head(module, name, masks, logits):\n    # Note that the mask has already been applied to the logits at this point\n    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)\n    if name in masks:\n        n = torch.sum(masks[name], dim=-1, dtype=torch.float)\n        norm_entropy = entropy / torch.log(n)\n        # When the mask only allows one option the normalized entropy makes no sense\n        # as it is basically both maximal (the distribution is as uniform as it can be)\n        # and minimal (there is no variance at all).\n        # As such, we ignore them for the purpose of calculating entropy.\n        zero = torch.zeros_like(norm_entropy)\n        norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy)\n        count = n.not_equal(1.0).int()\n    else:\n        n = torch.tensor(logits.shape[-1], 
dtype=torch.float)\n        norm_entropy = entropy / torch.log(n)\n        count = torch.ones_like(norm_entropy, dtype=torch.int)\n\n    # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.\n    for _ in module.output_shape[:-1]:\n        norm_entropy = norm_entropy.sum(dim=-1)\n        count = count.sum(dim=-1)\n    return norm_entropy, count\n\n\ndef get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]:\n    entropy_sum = torch.zeros_like(template, dtype=torch.float)\n    counts = torch.zeros_like(template, dtype=torch.int)\n    for k, subhead in module.items():\n        if isinstance(subhead, DictActionHead):\n            entropy, count = get_norm_cat_entropy(subhead, masks, logits[k], template)\n        elif isinstance(subhead, CategoricalActionHead):\n            entropy, count = get_norm_entropy_from_cat_head(subhead, k, masks, logits[k])\n        else:\n            continue\n        entropy_sum += entropy\n        counts += count\n    return entropy_sum, counts\n\n\ndef get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]:\n    entropy_sum = torch.zeros_like(template, dtype=torch.float)\n    count = torch.zeros(1, device=template.device, dtype=torch.int)\n    for k, subhead in module.items():\n        if isinstance(subhead, DictActionHead):\n            entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template)\n        elif isinstance(subhead, DiagGaussianActionHead):\n            entropy_sum += module.entropy(logits)\n        else:\n            continue\n        count += 1\n    return entropy_sum / count\n"
  },
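The `store_args` decorator above records `__init__` arguments (including defaults) as instance attributes, which is how classes like `ActionTransformer` and `MaskedAttention` get fields such as `self.camera_maxval` without explicit assignments. A condensed sketch of the same idea, with a hypothetical `Camera` class as the usage example:

```python
import functools
import inspect

def store_args(method):
    """Store an __init__'s named arguments (including defaults) on self."""
    argspec = inspect.getfullargspec(method)
    defaults = {}
    if argspec.defaults is not None:
        defaults = dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
    arg_names = argspec.args[1:]  # skip "self"

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        values = defaults.copy()          # start from declared defaults
        values.update(zip(arg_names, args))  # overlay positional args
        values.update(kwargs)             # overlay keyword args
        self.__dict__.update(values)
        return method(self, *args, **kwargs)

    return wrapper

class Camera:
    @store_args
    def __init__(self, maxval, binsize=2):
        pass  # attributes are populated by the decorator
```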
  {
    "path": "metrics/IDM/lib/misc.py",
    "content": "import numpy as np\nimport torch as th\n\n\ndef intprod(xs):\n    \"\"\"\n    Product of a sequence of integers\n    \"\"\"\n    out = 1\n    for x in xs:\n        out *= x\n    return out\n\n\ndef safezip(*args):\n    \"\"\"\n    Check that lengths of sequences are the same, then zip them\n    \"\"\"\n    args = [list(a) for a in args]\n    n = len(args[0])\n    for arg in args[1:]:\n        assert len(arg) == n, f\"length mismatch: {list(map(len, args))}\"\n    return list(zip(*args))\n\n\ndef transpose(x, before, after):\n    \"\"\"\n    Usage: x_bca = transpose(x_abc, 'abc', 'bca')\n    \"\"\"\n    assert sorted(before) == sorted(after), f\"cannot transpose {before} to {after}\"\n    assert x.ndim == len(\n        before\n    ), f\"before spec '{before}' has length {len(before)} but x has {x.ndim} dimensions: {tuple(x.shape)}\"\n    return x.permute(tuple(before.index(i) for i in after))\n\n\ndef transpose_undo(x, before, after, *, undo=None):\n    \"\"\"\n    Usage:\n    x_bca, undo = transpose_undo(x_abc, 'abc', 'bca')\n    x_bca = fully_connected_layer(x_bca)\n    x_abc = undo(x_bca)\n    \"\"\"\n    return (\n        transpose(x, before, after),\n        compose_undo(undo, lambda x: transpose(x, before=after, after=before)),\n    )\n\n\ndef compose_undo(u1, u2):\n    assert u2 is not None\n    if u1 is None:\n        return u2\n\n    def u(x):\n        x = u2(x)\n        x = u1(x)\n        return x\n\n    return u\n\n\nNO_BIND = \"__nobind\"\n\n\ndef _parse_reshape_str(s, kind):\n    assert kind in (\"before\", \"after\")\n    result = []\n    n_underscores = 0\n    for i, part in enumerate(s.split(\",\")):\n        part = part.strip()\n        if part == \"?\" and kind == \"before\":\n            result.append([f\"__{i}\"])\n        elif part == \"_\":\n            result.append([f\"{NO_BIND}_{n_underscores}\"])\n            n_underscores += 1\n        else:\n            result.append([term.strip() for term in part.split(\"*\")])\n    
return result\n\n\ndef _infer_part(part, concrete_dim, known, index, full_shape):\n    if type(part) is int:\n        return part\n    assert isinstance(part, list), part\n    lits = []\n    syms = []\n    for term in part:\n        if type(term) is int:\n            lits.append(term)\n        elif type(term) is str:\n            syms.append(term)\n        else:\n            raise TypeError(f\"got {type(term)} but expected int or str\")\n    int_part = 1\n    for x in lits:\n        int_part *= x\n    if len(syms) == 0:\n        return int_part\n    elif len(syms) == 1 and concrete_dim is not None:\n        assert concrete_dim % int_part == 0, f\"{concrete_dim} % {int_part} != 0 (at index {index}, full shape is {full_shape})\"\n        v = concrete_dim // int_part\n        if syms[0] in known:\n            assert (\n                known[syms[0]] == v\n            ), f\"known value for {syms[0]} is {known[syms[0]]} but found value {v} at index {index} (full shape is {full_shape})\"\n        else:\n            known[syms[0]] = v\n        return concrete_dim\n    else:\n        for i in range(len(syms)):\n            if syms[i] in known:\n                syms[i] = known[syms[i]]\n            else:\n                try:\n                    syms[i] = int(syms[i])\n                except ValueError:\n                    pass\n        return lits + syms\n\n\ndef _infer_step(args):\n    known, desc, shape = args\n    new_known = known.copy()\n    new_desc = desc.copy()\n    for i in range(len(desc)):\n        if shape is None:\n            concrete_dim = None\n        else:\n            concrete_dim = shape[i]\n        new_desc[i] = _infer_part(part=desc[i], concrete_dim=concrete_dim, known=new_known, index=i, full_shape=shape)\n    return new_known, new_desc, shape\n\n\ndef _infer(known, desc, shape):\n    if shape is not None:\n        assert len(desc) == len(shape), f\"desc has length {len(desc)} but shape has length {len(shape)} (shape={shape})\"\n    known, desc, 
shape = fixed_point(_infer_step, (known, desc, shape))\n    return desc, known\n\n\ndef fixed_point(f, x, eq=None):\n    if eq is None:\n        eq = lambda a, b: a == b\n    while True:\n        new_x = f(x)\n        if eq(x, new_x):\n            return x\n        else:\n            x = new_x\n\n\ndef _infer_question_mark(x, total_product):\n    try:\n        question_mark_index = x.index([\"?\"])\n    except ValueError:\n        return x\n    observed_product = 1\n    for i in range(len(x)):\n        if i != question_mark_index:\n            assert type(x[i]) is int, f\"when there is a question mark, there can be no other unknown values (full list: {x})\"\n            observed_product *= x[i]\n    assert (\n        observed_product and total_product % observed_product == 0\n    ), f\"{total_product} is not divisible by {observed_product}\"\n    value = total_product // observed_product\n    x = x.copy()\n    x[question_mark_index] = value\n    return x\n\n\ndef _ground(x, known, infer_question_mark_with=None):\n    x, known = _infer(known=known, desc=x, shape=None)\n    if infer_question_mark_with:\n        x = _infer_question_mark(x, infer_question_mark_with)\n    for part in x:\n        assert type(part) is int, f\"cannot infer value of {part}\"\n    return x\n\n\ndef _handle_ellipsis(x, before, after):\n    ell = [\"...\"]\n    try:\n        i = before.index(ell)\n        l = len(x.shape) - len(before) + 1\n        ellipsis_value = x.shape[i : i + l]\n        ellipsis_value = list(ellipsis_value)\n        before = before[:i] + ellipsis_value + before[i + 1 :]\n    except ValueError:\n        pass\n    try:\n        i = after.index(ell)\n        after = after[:i] + ellipsis_value + after[i + 1 :]\n    except ValueError:\n        pass\n    except UnboundLocalError as e:\n        raise ValueError(\"there cannot be an ellipsis in 'after' unless there is an ellipsis in 'before'\") from e\n    return before, after\n\n\ndef reshape_undo(inp, before, after, *, 
undo=None, known=None, **kwargs):\n    \"\"\"\n    Usage:\n    x_Bhwse, undo = reshape_undo(\n        x_bthwe,\n        'b, t, ..., stride*e',\n        'b*t, ..., stride, e',\n        stride=7\n    )\n    x_Bhwse = do_some_stuff(x_Bhwse)\n    x_bthwe = undo(x_Bhwse)\n\n    It's necessary to pass known values as keywords only\n    when they can't be inferred from the shape.\n\n    (Eg. in the above example we needed to pass\n    stride but not b, t, or e, since those can be determined from\n    inp.shape once stride is known.)\n    \"\"\"\n    if known:\n        known = {**kwargs, **known}\n    else:\n        known = kwargs\n    assert type(before) is type(after), f\"{type(before)} != {type(after)}\"\n    assert isinstance(inp, (th.Tensor, np.ndarray)), f\"require tensor or ndarray but got {type(inp)}\"\n    assert isinstance(before, (str, list)), f\"require str or list but got {type(before)}\"\n    if isinstance(before, str):\n        before = _parse_reshape_str(before, \"before\")\n        after = _parse_reshape_str(after, \"after\")\n        before, after = _handle_ellipsis(inp, before, after)\n    before_saved, after_saved = before, after\n    before, known = _infer(known=known, desc=before, shape=inp.shape)\n    before = _ground(before, known, product(inp.shape))\n    after = _ground(after, known, product(inp.shape))\n    known = {k: v for k, v in known.items() if not k.startswith(NO_BIND)}\n    assert tuple(inp.shape) == tuple(before), f\"expected shape {before} but got shape {inp.shape}\"\n    assert product(inp.shape) == product(\n        after\n    ), f\"cannot reshape {inp.shape} to {after} because the number of elements does not match\"\n    return (\n        inp.reshape(after),\n        compose_undo(undo, lambda inp: reshape(inp, after_saved, before_saved, known=known)),\n    )\n\n\ndef reshape(*args, **kwargs):\n    \"\"\"\n    Please see the documentation for reshape_undo.\n    \"\"\"\n    x, _ = reshape_undo(*args, **kwargs)\n    return x\n\n\ndef 
product(xs, one=1):\n    result = one\n    for x in xs:\n        result = result * x\n    return result\n\n\ndef exact_div(a, b):\n    assert a % b == 0, f\"{a} is not divisible by {b}\"\n    return a // b\n"
  },
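The shape helpers in the file above resolve unknown dimensions by iterating to a fixed point and by dividing the total element count by the product of the known dimensions. A minimal, self-contained sketch of those two ideas: `fixed_point` mirrors the helper in the file, while `solve_unknown` is a hypothetical simplification of `_infer_question_mark` (names and signatures here are illustrative, not the file's API).

```python
# Fixed-point iteration as used by the shape-inference code: apply f
# repeatedly until the value stops changing.
def fixed_point(f, x, eq=None):
    if eq is None:
        eq = lambda a, b: a == b
    while True:
        new_x = f(x)
        if eq(x, new_x):
            return x
        x = new_x


# Hypothetical simplification of _infer_question_mark: fill in a single
# "?" dimension so that the dimensions multiply to the known total.
def solve_unknown(shape, total_elements):
    known_product = 1
    for d in shape:
        if isinstance(d, int):
            known_product *= d
    assert total_elements % known_product == 0, "total must be divisible"
    return [d if isinstance(d, int) else total_elements // known_product
            for d in shape]


print(fixed_point(lambda n: n // 2 if n % 2 == 0 else n, 40))  # -> 5
print(solve_unknown([2, "?", 4], 24))                          # -> [2, 3, 4]
```

The fixed-point loop is what lets partially-known shape descriptions converge: each pass may pin down another symbol, and iteration stops once no new information appears.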
  {
    "path": "metrics/IDM/lib/mlp.py",
    "content": "import torch as th\nfrom torch import nn\n\nfrom lib import misc\nfrom lib import torch_util as tu\n\n\nclass MLP(nn.Module):\n    def __init__(self, insize, nhidlayer, outsize, hidsize, hidactiv, dtype=th.float32):\n        super().__init__()\n        self.insize = insize\n        self.nhidlayer = nhidlayer\n        self.outsize = outsize\n        in_sizes = [insize] + [hidsize] * nhidlayer\n        out_sizes = [hidsize] * nhidlayer + [outsize]\n        self.layers = nn.ModuleList(\n            [tu.NormedLinear(insize, outsize, dtype=dtype) for (insize, outsize) in misc.safezip(in_sizes, out_sizes)]\n        )\n        self.hidactiv = hidactiv\n\n    def forward(self, x):\n        *hidlayers, finallayer = self.layers\n        for layer in hidlayers:\n            x = layer(x)\n            x = self.hidactiv(x)\n        x = finallayer(x)\n        return x\n\n    @property\n    def output_shape(self):\n        return (self.outsize,)\n"
  },
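The `MLP` constructor above derives one `(in, out)` pair per linear layer by offsetting the hidden-size list against itself. A small pure-Python sketch of that bookkeeping, with illustrative sizes (no torch needed to see the pattern):

```python
# Mirror of the size pairing in MLP.__init__: [insize] + hidden sizes
# zip against hidden sizes + [outsize], giving one pair per linear layer.
def layer_sizes(insize, nhidlayer, outsize, hidsize):
    in_sizes = [insize] + [hidsize] * nhidlayer
    out_sizes = [hidsize] * nhidlayer + [outsize]
    return list(zip(in_sizes, out_sizes))


# Two hidden layers of width 128 between a 64-d input and a 10-d output.
print(layer_sizes(64, 2, 10, 128))  # -> [(64, 128), (128, 128), (128, 10)]
```

This is why `forward` applies `hidactiv` to all but the last layer: the final `(hidsize, outsize)` pair is the output projection and is left linear.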
  {
    "path": "metrics/IDM/lib/normalize_ewma.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\n\n\nclass NormalizeEwma(nn.Module):\n    \"\"\"Normalize a vector of observations - across the first norm_axes dimensions\"\"\"\n\n    def __init__(self, input_shape, norm_axes=2, beta=0.99999, per_element_update=False, epsilon=1e-5):\n        super().__init__()\n\n        self.input_shape = input_shape\n        self.norm_axes = norm_axes\n        self.epsilon = epsilon\n        self.beta = beta\n        self.per_element_update = per_element_update\n\n        self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)\n        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)\n        self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False)\n\n    def reset_parameters(self):\n        self.running_mean.zero_()\n        self.running_mean_sq.zero_()\n        self.debiasing_term.zero_()\n\n    def running_mean_var(self):\n        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)\n        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)\n        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)\n        return debiased_mean, debiased_var\n\n    def forward(self, input_vector):\n        # Make sure input is float32\n        input_vector = input_vector.to(torch.float)\n\n        if self.training:\n            # Detach input before adding it to running means to avoid backpropping through it on\n            # subsequent batches.\n            detached_input = input_vector.detach()\n            batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))\n            batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))\n\n            if self.per_element_update:\n                batch_size = np.prod(detached_input.size()[: self.norm_axes])\n                weight 
= self.beta ** batch_size\n            else:\n                weight = self.beta\n\n            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))\n            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))\n            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))\n\n        mean, var = self.running_mean_var()\n        return (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]\n\n    def denormalize(self, input_vector):\n        \"\"\"Transform normalized data back into original distribution\"\"\"\n        mean, var = self.running_mean_var()\n        return input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]\n"
  },
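`NormalizeEwma` pairs its exponentially weighted running statistics with a debiasing term so early estimates are not pulled toward the zero initialization. A scalar sketch of the update rule (pure Python; `beta=0.9` is an illustrative value, the module defaults to 0.99999):

```python
# Scalar version of the NormalizeEwma update: blend in each batch mean
# with weight (1 - beta). The debiasing term accumulates the same
# weights, so mean / debias is an unbiased estimate even after only a
# few updates, when mean itself is still close to its zero init.
def ewma_update(mean, debias, batch_mean, beta=0.9):
    mean = mean * beta + batch_mean * (1.0 - beta)
    debias = debias * beta + 1.0 * (1.0 - beta)
    return mean, debias


mean, debias = 0.0, 0.0
for batch_mean in [1.0, 1.0, 1.0]:
    mean, debias = ewma_update(mean, debias, batch_mean)

print(round(mean / debias, 6))  # -> 1.0, even though mean itself is ~0.27
```

The `per_element_update` branch in the module applies the same idea but raises `beta` to the batch-size power, so larger batches move the statistics further in a single step.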
  {
    "path": "metrics/IDM/lib/policy.py",
    "content": "from copy import deepcopy\nfrom email import policy\nfrom typing import Dict, Optional\n\nimport numpy as np\nimport torch as th\nfrom gym3.types import DictType\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom lib.action_head import make_action_head\nfrom lib.action_mapping import CameraHierarchicalMapping\nfrom lib.impala_cnn import ImpalaCNN\nfrom lib.normalize_ewma import NormalizeEwma\nfrom lib.scaled_mse_head import ScaledMSEHead\nfrom lib.tree_util import tree_map\nfrom lib.util import FanInInitReLULayer, ResidualRecurrentBlocks\nfrom lib.misc import transpose\n\n\nclass ImgPreprocessing(nn.Module):\n    \"\"\"Normalize incoming images.\n\n    :param img_statistics: remote path to npz file with a mean and std image. If specified\n        normalize images using this.\n    :param scale_img: If true and img_statistics not specified, scale incoming images by 1/255.\n    \"\"\"\n\n    def __init__(self, img_statistics: Optional[str] = None, scale_img: bool = True):\n        super().__init__()\n        self.img_mean = None\n        if img_statistics is not None:\n            img_statistics = dict(**np.load(img_statistics))\n            self.img_mean = nn.Parameter(th.Tensor(img_statistics[\"mean\"]), requires_grad=False)\n            self.img_std = nn.Parameter(th.Tensor(img_statistics[\"std\"]), requires_grad=False)\n        else:\n            self.ob_scale = 255.0 if scale_img else 1.0\n\n    def forward(self, img):\n        x = img.to(dtype=th.float32)\n        if self.img_mean is not None:\n            x = (x - self.img_mean) / self.img_std\n        else:\n            x = x / self.ob_scale\n        return x\n\n\nclass ImgObsProcess(nn.Module):\n    \"\"\"ImpalaCNN followed by a linear layer.\n\n    :param cnn_outsize: impala output dimension\n    :param output_size: output size of the linear layer.\n    :param dense_init_norm_kwargs: kwargs for linear FanInInitReLULayer\n    :param init_norm_kwargs: kwargs for 2d and 3d conv 
FanInInitReLULayer\n    \"\"\"\n\n    def __init__(\n        self,\n        cnn_outsize: int,\n        output_size: int,\n        dense_init_norm_kwargs: Dict = {},\n        init_norm_kwargs: Dict = {},\n        **kwargs,\n    ):\n        super().__init__()\n        self.cnn = ImpalaCNN(\n            outsize=cnn_outsize,\n            init_norm_kwargs=init_norm_kwargs,\n            dense_init_norm_kwargs=dense_init_norm_kwargs,\n            **kwargs,\n        )\n        self.linear = FanInInitReLULayer(\n            cnn_outsize,\n            output_size,\n            layer_type=\"linear\",\n            **dense_init_norm_kwargs,\n        )\n\n    def forward(self, img):\n        return self.linear(self.cnn(img))\n\n\nclass MinecraftPolicy(nn.Module):\n    \"\"\"\n    :param recurrence_type:\n        None                - No recurrence, adds no extra layers\n        lstm                - (Depreciated). Singular LSTM\n        multi_layer_lstm    - Multi-layer LSTM. Uses n_recurrence_layers to determine number of consecututive LSTMs\n            Does NOT support ragged batching\n        multi_masked_lstm   - Multi-layer LSTM that supports ragged batching via the first vector. 
This model is slower\n            Uses n_recurrence_layers to determine number of consecututive LSTMs\n        transformer         - Dense transformer\n    :param init_norm_kwargs: kwargs for all FanInInitReLULayers.\n    \"\"\"\n\n    def __init__(\n        self,\n        recurrence_type=\"lstm\",\n        impala_width=1,\n        impala_chans=(16, 32, 32),\n        obs_processing_width=256,\n        hidsize=512,\n        single_output=False,  # True if we don't need separate outputs for action/value outputs\n        img_shape=None,\n        scale_input_img=True,\n        only_img_input=False,\n        init_norm_kwargs={},\n        impala_kwargs={},\n        # Unused argument assumed by forc.\n        input_shape=None,  # pylint: disable=unused-argument\n        active_reward_monitors=None,\n        img_statistics=None,\n        first_conv_norm=False,\n        diff_mlp_embedding=False,\n        attention_mask_style=\"clipped_causal\",\n        attention_heads=8,\n        attention_memory_size=2048,\n        use_pointwise_layer=True,\n        pointwise_ratio=4,\n        pointwise_use_activation=False,\n        n_recurrence_layers=1,\n        recurrence_is_residual=True,\n        timesteps=None,\n        use_pre_lstm_ln=True,  # Not needed for transformer\n        **unused_kwargs,\n    ):\n        super().__init__()\n        assert recurrence_type in [\n            \"multi_layer_lstm\",\n            \"multi_layer_bilstm\",\n            \"multi_masked_lstm\",\n            \"transformer\",\n            \"none\",\n        ]\n\n        active_reward_monitors = active_reward_monitors or {}\n\n        self.single_output = single_output\n\n        chans = tuple(int(impala_width * c) for c in impala_chans)\n        self.hidsize = hidsize\n\n        # Dense init kwargs replaces batchnorm/groupnorm with layernorm\n        self.init_norm_kwargs = init_norm_kwargs\n        self.dense_init_norm_kwargs = deepcopy(init_norm_kwargs)\n        if 
self.dense_init_norm_kwargs.get(\"group_norm_groups\", None) is not None:\n            self.dense_init_norm_kwargs.pop(\"group_norm_groups\", None)\n            self.dense_init_norm_kwargs[\"layer_norm\"] = True\n        if self.dense_init_norm_kwargs.get(\"batch_norm\", False):\n            self.dense_init_norm_kwargs.pop(\"batch_norm\", False)\n            self.dense_init_norm_kwargs[\"layer_norm\"] = True\n\n        # Setup inputs\n        self.img_preprocess = ImgPreprocessing(img_statistics=img_statistics, scale_img=scale_input_img)\n        self.img_process = ImgObsProcess(\n            cnn_outsize=256,\n            output_size=hidsize,\n            inshape=img_shape,\n            chans=chans,\n            nblock=2,\n            dense_init_norm_kwargs=self.dense_init_norm_kwargs,\n            init_norm_kwargs=init_norm_kwargs,\n            first_conv_norm=first_conv_norm,\n            **impala_kwargs,\n        )\n\n        self.pre_lstm_ln = nn.LayerNorm(hidsize) if use_pre_lstm_ln else None\n        self.diff_obs_process = None\n\n        self.recurrence_type = recurrence_type\n\n        self.recurrent_layer = None\n        self.recurrent_layer = ResidualRecurrentBlocks(\n            hidsize=hidsize,\n            timesteps=timesteps,\n            recurrence_type=recurrence_type,\n            is_residual=recurrence_is_residual,\n            use_pointwise_layer=use_pointwise_layer,\n            pointwise_ratio=pointwise_ratio,\n            pointwise_use_activation=pointwise_use_activation,\n            attention_mask_style=attention_mask_style,\n            attention_heads=attention_heads,\n            attention_memory_size=attention_memory_size,\n            n_block=n_recurrence_layers,\n        )\n\n        self.lastlayer = FanInInitReLULayer(hidsize, hidsize, layer_type=\"linear\", **self.dense_init_norm_kwargs)\n        self.final_ln = th.nn.LayerNorm(hidsize)\n\n    def output_latent_size(self):\n        return self.hidsize\n\n    def forward(self, ob, 
state_in, context):\n        first = context[\"first\"]\n\n        x = self.img_preprocess(ob[\"img\"])\n        x = self.img_process(x)\n\n        if self.diff_obs_process:\n            processed_obs = self.diff_obs_process(ob[\"diff_goal\"])\n            x = processed_obs + x\n\n        if self.pre_lstm_ln is not None:\n            x = self.pre_lstm_ln(x)\n\n        if self.recurrent_layer is not None:\n            x, state_out = self.recurrent_layer(x, first, state_in)\n        else:\n            state_out = state_in\n\n        x = F.relu(x, inplace=False)\n\n        x = self.lastlayer(x)\n        x = self.final_ln(x)\n        pi_latent = vf_latent = x\n        if self.single_output:\n            return pi_latent, state_out\n        return (pi_latent, vf_latent), state_out\n\n    def initial_state(self, batchsize):\n        if self.recurrent_layer:\n            return self.recurrent_layer.initial_state(batchsize)\n        else:\n            return None\n\n\nclass MinecraftAgentPolicy(nn.Module):\n    def __init__(self, action_space, policy_kwargs, pi_head_kwargs):\n        super().__init__()\n        self.net = MinecraftPolicy(**policy_kwargs)\n\n        self.action_space = action_space\n\n        self.value_head = self.make_value_head(self.net.output_latent_size())\n        self.pi_head = self.make_action_head(self.net.output_latent_size(), **pi_head_kwargs)\n\n    def make_value_head(self, v_out_size: int, norm_type: str = \"ewma\", norm_kwargs: Optional[Dict] = None):\n        return ScaledMSEHead(v_out_size, 1, norm_type=norm_type, norm_kwargs=norm_kwargs)\n\n    def make_action_head(self, pi_out_size: int, **pi_head_opts):\n        return make_action_head(self.action_space, pi_out_size, **pi_head_opts)\n\n    def initial_state(self, batch_size: int):\n        return self.net.initial_state(batch_size)\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.net.reset_parameters()\n        self.pi_head.reset_parameters()\n        
self.value_head.reset_parameters()\n\n    def forward(self, obs, first: th.Tensor, state_in):\n        if isinstance(obs, dict):\n            # We don't want to mutate the obs input.\n            obs = obs.copy()\n\n            # If special \"mask\" key is in obs,\n            # It's for masking the logits.\n            # We take it out (the network doesn't need it)\n            mask = obs.pop(\"mask\", None)\n        else:\n            mask = None\n\n        (pi_h, v_h), state_out = self.net(obs, state_in, context={\"first\": first})\n\n        pi_logits = self.pi_head(pi_h, mask=mask)\n        vpred = self.value_head(v_h)\n\n        return (pi_logits, vpred, None), state_out\n\n    def get_logprob_of_action(self, pd, action):\n        \"\"\"\n        Get logprob of taking action `action` given probability distribution\n        (see `get_gradient_for_action` to get this distribution)\n        \"\"\"\n        ac = tree_map(lambda x: x.unsqueeze(1), action)\n        log_prob = self.pi_head.logprob(ac, pd)\n        assert not th.isnan(log_prob).any()\n        return log_prob[:, 0]\n\n    def get_kl_of_action_dists(self, pd1, pd2):\n        \"\"\"\n        Get the KL divergence between two action probability distributions\n        \"\"\"\n        return self.pi_head.kl_divergence(pd1, pd2)\n\n    def get_output_for_observation(self, obs, state_in, first):\n        \"\"\"\n        Return gradient-enabled outputs for given observation.\n\n        Use `get_logprob_of_action` to get log probability of action\n        with the given probability distribution.\n\n        Returns:\n          - probability distribution given observation\n          - value prediction for given observation\n          - new state\n        \"\"\"\n        # We need to add a fictitious time dimension everywhere\n        obs = tree_map(lambda x: x.unsqueeze(1), obs)\n        first = first.unsqueeze(1)\n\n        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)\n\n        
return pd, self.value_head.denormalize(vpred)[:, 0], state_out\n\n    @th.no_grad()\n    def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False):\n        # We need to add a fictitious time dimension everywhere\n        obs = tree_map(lambda x: x.unsqueeze(1), obs)\n        first = first.unsqueeze(1)\n\n        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)\n\n        if taken_action is None:\n            ac = self.pi_head.sample(pd, deterministic=not stochastic)\n        else:\n            ac = tree_map(lambda x: x.unsqueeze(1), taken_action)\n        log_prob = self.pi_head.logprob(ac, pd)\n        assert not th.isnan(log_prob).any()\n\n        # After unsqueezing, squeeze back to remove fictitious time dimension\n        result = {\"log_prob\": log_prob[:, 0], \"vpred\": self.value_head.denormalize(vpred)[:, 0]}\n        if return_pd:\n            result[\"pd\"] = tree_map(lambda x: x[:, 0], pd)\n        ac = tree_map(lambda x: x[:, 0], ac)\n\n        return ac, state_out, result\n\n    @th.no_grad()\n    def v(self, obs, first, state_in):\n        \"\"\"Predict value for a given mdp observation\"\"\"\n        obs = tree_map(lambda x: x.unsqueeze(1), obs)\n        first = first.unsqueeze(1)\n\n        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)\n\n        # After unsqueezing, squeeze back\n        return self.value_head.denormalize(vpred)[:, 0]\n\n\nclass InverseActionNet(MinecraftPolicy):\n    \"\"\"\n    Args:\n        conv3d_params: PRE impala 3D CNN params. 
They are just passed into th.nn.Conv3D.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidsize=512,\n        conv3d_params=None,\n        **MCPolicy_kwargs,\n    ):\n        super().__init__(\n            hidsize=hidsize,\n            # If we're using 3dconv, then we normalize entire impala otherwise don't\n            # normalize the first impala layer since we normalize the input\n            first_conv_norm=conv3d_params is not None,\n            **MCPolicy_kwargs,\n        )\n        self.conv3d_layer = None\n        if conv3d_params is not None:\n            # 3D conv is the first layer, so don't normalize its input\n            conv3d_init_params = deepcopy(self.init_norm_kwargs)\n            conv3d_init_params[\"group_norm_groups\"] = None\n            conv3d_init_params[\"batch_norm\"] = False\n            self.conv3d_layer = FanInInitReLULayer(\n                layer_type=\"conv3d\",\n                log_scope=\"3d_conv\",\n                **conv3d_params,\n                **conv3d_init_params,\n            )\n\n    def forward(self, ob, state_in, context):\n        first = context[\"first\"]\n        x = self.img_preprocess(ob[\"img\"])\n\n        # Conv3D Prior to Impala\n        if self.conv3d_layer is not None:\n            x = self._conv3d_forward(x)\n\n        # Impala Stack\n        x = self.img_process(x)\n\n        if self.recurrent_layer is not None:\n            x, state_out = self.recurrent_layer(x, first, state_in)\n        else:\n            state_out = state_in\n\n        x = F.relu(x, inplace=False)\n\n        pi_latent = self.lastlayer(x)\n        pi_latent = self.final_ln(pi_latent)\n        return (pi_latent, None), state_out\n\n    def _conv3d_forward(self, x):\n        # Convert from (B, T, H, W, C) -> (B, C, T, H, W)\n        x = transpose(x, \"bthwc\", \"bcthw\")\n        new_x = []\n        for mini_batch in th.split(x, 1):\n            new_x.append(self.conv3d_layer(mini_batch))\n        x = th.cat(new_x)\n        # Convert back\n        x = transpose(x, \"bcthw\", \"bthwc\")\n        return x\n\n\nclass InverseActionPolicy(nn.Module):\n    def __init__(\n        self,\n        action_space,\n        pi_head_kwargs=None,\n        idm_net_kwargs=None,\n    ):\n        super().__init__()\n        self.action_space = action_space\n\n        self.net = InverseActionNet(**idm_net_kwargs)\n\n        pi_out_size = self.net.output_latent_size()\n\n        pi_head_kwargs = {} if pi_head_kwargs is None else pi_head_kwargs\n\n        self.pi_head = self.make_action_head(pi_out_size=pi_out_size, **pi_head_kwargs)\n\n    def make_action_head(self, **kwargs):\n        return make_action_head(self.action_space, **kwargs)\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.net.reset_parameters()\n        self.pi_head.reset_parameters()\n\n    def forward(self, obs, first: th.Tensor, state_in, **kwargs):\n        if isinstance(obs, dict):\n            # We don't want to mutate the obs input.\n            obs = obs.copy()\n\n            # If special \"mask\" key is in obs,\n            # It's for masking the logits.\n            # We take it out (the network doesn't need it)\n            mask = obs.pop(\"mask\", None)\n        else:\n            mask = None\n\n        (pi_h, _), state_out = self.net(obs, state_in=state_in, context={\"first\": first}, **kwargs)\n        pi_logits = self.pi_head(pi_h, mask=mask)\n        return (pi_logits, None, None), state_out\n\n    @th.no_grad()\n    def predict(\n        self,\n        obs,\n        deterministic: bool = True,\n        **kwargs,\n    ):\n        (pd, _, _), state_out = self(obs=obs, **kwargs)\n\n        ac = self.pi_head.sample(pd, deterministic=deterministic)\n        log_prob = self.pi_head.logprob(ac, pd)\n\n        assert not th.isnan(log_prob).any()\n\n        result = {\"log_prob\": log_prob, \"pd\": pd}\n\n        return ac, state_out, result\n\n    def initial_state(self, batch_size: int):\n        return self.net.initial_state(batch_size)\n"
  },
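`MinecraftAgentPolicy.act` and `v` wrap single-step observations in a fictitious time dimension (`unsqueeze(1)`) so the sequence model sees shape `(B, T=1, ...)`, then index it away with `[:, 0]` on the way out. A toy sketch of that round trip, with nested lists standing in for tensors and a minimal `tree_map` over dicts (the real `tree_map` in `lib.tree_util` handles general pytrees):

```python
# Minimal tree_map: apply f to every leaf of a nested dict.
def tree_map(f, tree):
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    return f(tree)


obs = {"img": [[1, 2], [3, 4]]}  # batch of 2 "frames"; lists stand in for tensors

# Add a time axis of length 1 (the unsqueeze(1) step) ...
obs_t = tree_map(lambda x: [[row] for row in x], obs)
# ... the sequence model would run on obs_t here ...
# then squeeze the fictitious axis back out (the [:, 0] step).
obs_back = tree_map(lambda x: [row[0] for row in x], obs_t)

print(obs_back == obs)  # -> True
```

Keeping the time axis explicit lets the same recurrent/transformer stack serve both batched training on sequences and one-frame-at-a-time acting.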
  {
    "path": "metrics/IDM/lib/scaled_mse_head.py",
    "content": "from typing import Dict, Optional\n\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\n\nfrom lib.action_head import fan_in_linear\nfrom lib.normalize_ewma import NormalizeEwma\n\n\nclass ScaledMSEHead(nn.Module):\n    \"\"\"\n    Linear output layer that scales itself so that targets are always normalized to N(0, 1)\n    \"\"\"\n\n    def __init__(\n        self, input_size: int, output_size: int, norm_type: Optional[str] = \"ewma\", norm_kwargs: Optional[Dict] = None\n    ):\n        super().__init__()\n        self.input_size = input_size\n        self.output_size = output_size\n        self.norm_type = norm_type\n\n        self.linear = nn.Linear(self.input_size, self.output_size)\n\n        norm_kwargs = {} if norm_kwargs is None else norm_kwargs\n        self.normalizer = NormalizeEwma(output_size, **norm_kwargs)\n\n    def reset_parameters(self):\n        init.orthogonal_(self.linear.weight)\n        fan_in_linear(self.linear)\n        self.normalizer.reset_parameters()\n\n    def forward(self, input_data):\n        return self.linear(input_data)\n\n    def loss(self, prediction, target):\n        \"\"\"\n        Calculate the MSE loss between output and a target.\n        'Prediction' has to be normalized while target is denormalized.\n        Loss is calculated in a 'normalized' space.\n        \"\"\"\n        return F.mse_loss(prediction, self.normalizer(target), reduction=\"mean\")\n\n    def denormalize(self, input_data):\n        \"\"\"Convert input value from a normalized space into the original one\"\"\"\n        return self.normalizer.denormalize(input_data)\n\n    def normalize(self, input_data):\n        return self.normalizer(input_data)\n"
  },
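`ScaledMSEHead.loss` compares the raw head output against the *normalized* target, and callers use `denormalize` to map predictions back to the original scale. A scalar sketch of that convention; the running mean/std values here are hypothetical stand-ins for what the EWMA normalizer would track:

```python
# Hypothetical running statistics, as a NormalizeEwma-style tracker
# might hold for the value target.
mean, std = 10.0, 2.0
normalize = lambda x: (x - mean) / std
denormalize = lambda x: x * std + mean

prediction = 1.5   # head output, already in the normalized space
target = 13.0      # raw value target, in the original space

loss = (prediction - normalize(target)) ** 2  # MSE taken in N(0, 1) space
print(loss)                      # -> 0.0, since normalize(13.0) == 1.5
print(denormalize(prediction))   # -> 13.0, the value reported to callers
```

Training in the normalized space keeps the regression target well-scaled regardless of the magnitude of returns, which is the point of pairing the linear head with the EWMA normalizer.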
  {
    "path": "metrics/IDM/lib/torch_util.py",
    "content": "import functools\nimport itertools\nimport math\nimport os\nimport pickle\nimport re\nimport subprocess\nimport tempfile\nfrom contextlib import contextmanager\nfrom hashlib import md5, sha1\n\nimport numpy as np\nimport torch as th\nimport torch.distributed as dist\nimport torch.distributions as dis\nimport torch.nn.functional as F\nfrom torch import nn\n\nimport lib.tree_util as tree_util\nfrom lib import misc\n\n\ndef contextmanager_to_decorator(cm):\n    def decorator(fn):\n        @functools.wraps(fn)\n        def newfn(*args, **kwargs):\n            with cm():\n                return fn(*args, **kwargs)\n\n        return newfn\n\n    return decorator\n\n\ndef have_cuda():\n    return th.has_cuda\n\n\ndef default_device_type():\n    return \"cuda\" if have_cuda() else \"cpu\"\n\n\nno_grad = contextmanager_to_decorator(th.no_grad)\nDEFAULT_DEVICE = th.device(type=default_device_type())\n\n\ndef set_default_torch_device(device):\n    global DEFAULT_DEVICE\n    DEFAULT_DEVICE = th.device(device)\n\n\ndef dev():\n    return DEFAULT_DEVICE\n\n\ndef zeros(*args, **kwargs):\n    return th.zeros(*args, **kwargs, device=dev())\n\n\ndef ones(*args, **kwargs):\n    return th.ones(*args, **kwargs, device=dev())\n\n\ndef arange(*args, **kwargs):\n    return th.arange(*args, **kwargs, device=dev())\n\n\ndef NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):\n    \"\"\"\n    nn.Linear but with normalized fan-in init\n    \"\"\"\n    dtype = parse_dtype(dtype)\n    if dtype == th.float32:\n        out = nn.Linear(*args, **kwargs)\n    elif dtype == th.float16:\n        out = LinearF16(*args, **kwargs)\n    else:\n        raise ValueError(dtype)\n    out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)\n    if kwargs.get(\"bias\", True):\n        out.bias.data *= 0\n    return out\n\n\nclass LinearF16(nn.Linear):\n    def forward(self, x):\n        return F.linear(x, self.weight.half(), self.bias.half() if self.bias is not None else 
None)\n\n\nclass LayerNormF16(nn.LayerNorm):\n    def forward(self, x):\n        return F.layer_norm(x, self.normalized_shape, self.weight.half(), self.bias.half(), self.eps)\n\n\ndef LayerNorm(*args, dtype=th.float32, **kwargs):\n    dtype = parse_dtype(dtype)\n    if dtype == th.float32:\n        out = nn.LayerNorm(*args, **kwargs)\n    elif dtype == th.float16:\n        out = LayerNormF16(*args, **kwargs)\n    else:\n        raise ValueError(dtype)\n    out.weight.no_scale = True\n    return out\n\n\ndef flatten_image(x):\n    \"\"\"\n    Flattens last three dims\n    \"\"\"\n    *batch_shape, h, w, c = x.shape\n    return x.reshape((*batch_shape, h * w * c))\n\n\ndef sequential(layers, x, *args, diag_name=None, use_checkpoint=False):\n    for (i, layer) in enumerate(layers):\n        x = layer(x, *args)\n    return x\n\n\n@no_grad\ndef load_average_with_metadata(paths, overrides):\n    n_models = len(paths)\n    model, metadata = load_with_metadata(paths[0], overrides=overrides)\n    for p in model.parameters():\n        p.mul_(1 / n_models)\n    for p in paths[1:]:\n        new_model, _ = load_with_metadata(p, overrides=overrides)\n        for (n1, p1), (n2, p2) in misc.safezip(model.named_parameters(), new_model.named_parameters()):\n            assert n1 == n2, f\"names {n1} and {n2} don't match\"\n            p1.add_(p2.mul_(1 / n_models))\n    return model, metadata\n\n\ndef save_kwargs(fn):\n    \"\"\"\n    This decorator passes through the user-provided kwargs and adds one more, called\n    save_kwargs, mapping to {\"create_fn\" : name_of_decorated_fn, \"kwargs\" : other_kwargs}\n\n    You put on this decorator on a function that creates a pytorch module. 
This will\n    save the kwargs and the function that was used to create the module.\n    This lets us restore the model state later.\n    \"\"\"\n\n    @functools.wraps(fn)\n    def wrapper(**kwargs):\n        if \"save_kwargs\" in kwargs:\n            return fn(**kwargs)\n        else:\n            sk = {**kwargs, \"create_fn\": f\"{fn.__module__}:{fn.__name__}\"}\n            return fn(save_kwargs=sk, **kwargs)\n\n    return wrapper\n\n\ndef parse_dtype(x):\n    if isinstance(x, th.dtype):\n        return x\n    elif isinstance(x, str):\n        if x == \"float32\" or x == \"float\":\n            return th.float32\n        elif x == \"float64\" or x == \"double\":\n            return th.float64\n        elif x == \"float16\" or x == \"half\":\n            return th.float16\n        elif x == \"uint8\":\n            return th.uint8\n        elif x == \"int8\":\n            return th.int8\n        elif x == \"int16\" or x == \"short\":\n            return th.int16\n        elif x == \"int32\" or x == \"int\":\n            return th.int32\n        elif x == \"int64\" or x == \"long\":\n            return th.int64\n        elif x == \"bool\":\n            return th.bool\n        else:\n            raise ValueError(f\"cannot parse {x} as a dtype\")\n    else:\n        raise TypeError(f\"cannot parse {type(x)} as dtype\")\n\n\ndef index(x, i):\n    \"\"\"\n    Batched, broadcasting index of x along dimension i.ndim.\n\n    For example, if x has shape (1, 2, 3, 4, 5) and i has shape (1, 1, 3)\n    then the result has shape (1, 2, 3, 5) and each value in i must be between 0 and 3.\n    \"\"\"\n    assert x.ndim >= i.ndim + 1\n    gather_dim = i.ndim\n    while i.ndim < x.ndim:\n        i = i.unsqueeze(-1)\n    expand_shape = list(x.shape)\n    expand_shape[gather_dim] = 1\n    i = i.expand(*expand_shape)\n    xi = th.gather(x, gather_dim, i)\n    assert xi.shape[gather_dim] == 1\n    return xi.squeeze(gather_dim)\n"
  },
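`parse_dtype` in the file above accepts either a `torch.dtype` or one of several string aliases. A dependency-free sketch of the same dispatch, with plain strings standing in for the torch dtype objects (the function name here is illustrative):

```python
# String-alias table mirroring the if/elif chain in parse_dtype.
ALIASES = {
    "float": "float32", "double": "float64", "half": "float16",
    "short": "int16", "int": "int32", "long": "int64",
}
CANONICAL = {
    "float32", "float64", "float16",
    "uint8", "int8", "int16", "int32", "int64", "bool",
}


def parse_dtype_name(x: str) -> str:
    x = ALIASES.get(x, x)          # map aliases onto canonical names
    if x not in CANONICAL:
        raise ValueError(f"cannot parse {x} as a dtype")
    return x


print(parse_dtype_name("half"))  # -> float16
print(parse_dtype_name("long"))  # -> int64
```

A table lookup keeps the alias set in one place, which is the main maintenance advantage over a long if/elif chain when new dtypes are added.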
  {
    "path": "metrics/IDM/lib/tree_util.py",
    "content": "# Copyright 2018 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copied this from jax, made it self-contained\n# Currently just used for improved_checkpoint\n\nimport collections\nimport functools\nimport itertools as it\nfrom collections.abc import Collection\nfrom typing import Dict, List, Optional\n\n\ndef unzip2(xys):\n    xs = []\n    ys = []\n    for x, y in xys:\n        xs.append(x)\n        ys.append(y)\n    return tuple(xs), tuple(ys)\n\n\ndef partial(fun, *args, **kwargs):\n    wrapped = functools.partial(fun, *args, **kwargs)\n    functools.update_wrapper(wrapped, fun)\n    wrapped._bound_args = args  # pylint: disable=protected-access\n    return wrapped\n\n\ndef safe_zip(*args: Collection) -> List[tuple]:\n    n = len(args[0])\n    for arg in args[1:]:\n        assert len(arg) == n, \"length mismatch: {}\".format(list(map(len, args)))\n    return list(zip(*args))\n\n\ndef safe_map(f, *args):\n    args = list(map(list, args))\n    n = len(args[0])\n    for arg in args[1:]:\n        assert len(arg) == n, \"length mismatch: {}\".format(list(map(len, args)))\n    return list(map(f, *args))\n\n\ndef tree_map(f, tree, treat_as_leaves: Optional[List] = None):\n    \"\"\"Map a function over a pytree to produce a new pytree.\n\n    Args:\n      f: function to be applied at each leaf.\n      tree: a pytree to be mapped over.\n\n    Returns:\n      A new pytree with the same structure as `tree` but with the value at each\n    
  leaf given by `f(x)` where `x` is the value at the corresponding leaf in\n      `tree`.\n    \"\"\"\n    if treat_as_leaves is None:\n        treat_as_leaves = []\n    node_type = node_types.get(type(tree))\n    if node_type and type(tree) not in treat_as_leaves:\n        children, node_spec = node_type.to_iterable(tree)\n        new_children = [tree_map(f, child, treat_as_leaves) for child in children]\n        return node_type.from_iterable(node_spec, new_children)\n    else:\n        return f(tree)\n\n\ndef tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None):\n    \"\"\"Map a multi-input function over pytree args to produce a new pytree.\n\n    Args:\n      f: function that takes `1 + len(rest)` arguments, to be applied at the\n        corresponding leaves of the pytrees.\n      tree: a pytree to be mapped over, with each leaf providing the first\n        positional argument to `f`.\n      *rest: a tuple of pytrees, each with the same structure as `tree`.\n\n    Returns:\n      A new pytree with the same structure as `tree` but with the value at each\n      leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf\n      in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.\n    \"\"\"\n\n    if treat_as_leaves is None:\n        treat_as_leaves = []\n    node_type = node_types.get(type(tree))\n    if node_type and type(tree) not in treat_as_leaves:\n        children, node_spec = node_type.to_iterable(tree)\n        all_children = [children]\n        for other_tree in rest:\n            other_children, other_node_data = node_type.to_iterable(other_tree)\n            if other_node_data != node_spec:\n                raise TypeError(\"Mismatch: {} != {}\".format(other_node_data, node_spec))\n            all_children.append(other_children)\n\n        new_children = [tree_multimap(f, *xs, treat_as_leaves=treat_as_leaves) for xs in zip(*all_children)]\n        return node_type.from_iterable(node_spec, 
new_children)\n    else:\n        return f(tree, *rest)\n\n\ndef prefix_multimap(f, treedef, tree, *rest):\n    \"\"\"Like tree_multimap but only maps down through a tree prefix.\"\"\"\n    if isinstance(treedef, PyLeaf):\n        return f(tree, *rest)\n    else:\n        node_type = node_types.get(type(tree))\n        if node_type != treedef.node_type:\n            raise TypeError(\"Mismatch: {} != {}\".format(treedef.node_type, node_type))\n        children, node_data = node_type.to_iterable(tree)\n        if node_data != treedef.node_data:\n            raise TypeError(\"Mismatch: {} != {}\".format(treedef.node_data, node_data))\n        all_children = [children]\n        for other_tree in rest:\n            other_children, other_node_data = node_type.to_iterable(other_tree)\n            if other_node_data != node_data:\n                raise TypeError(\"Mismatch: {} != {}\".format(other_node_data, node_data))\n            all_children.append(other_children)\n        all_children = zip(*all_children)\n\n        new_children = [prefix_multimap(f, td, *xs) for td, xs in zip(treedef.children, all_children)]\n        return node_type.from_iterable(node_data, new_children)\n\n\ndef walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = None):\n    node_type = node_types.get(type(tree))\n    if treat_as_leaves is None:\n        treat_as_leaves = []\n\n    if node_type and type(tree) not in treat_as_leaves:\n        children, node_spec = node_type.to_iterable(tree)\n        proc_children, child_specs = unzip2([walk_pytree(f_node, f_leaf, child, treat_as_leaves) for child in children])\n        tree_def = PyTreeDef(node_type, node_spec, child_specs)\n        return f_node(proc_children), tree_def\n    else:\n        return f_leaf(tree), PyLeaf()\n\n\ndef build_tree(treedef, xs):\n    if isinstance(treedef, PyLeaf):\n        return xs\n    else:\n        # We use 'iter' for clearer error messages\n        children = safe_map(build_tree, 
iter(treedef.children), iter(xs))\n        return treedef.node_type.from_iterable(treedef.node_data, children)\n\n\ndef _tree_unflatten(xs, treedef):\n    if isinstance(treedef, PyLeaf):\n        return next(xs)\n    else:\n        children = safe_map(partial(_tree_unflatten, xs), treedef.children)\n        return treedef.node_type.from_iterable(treedef.node_data, children)\n\n\ndef _num_leaves(treedef):\n    return 1 if isinstance(treedef, PyLeaf) else sum(safe_map(_num_leaves, treedef.children))\n\n\ndef _nested_treedef(inner, outer):\n    # just used in tree_transpose error checking\n    if isinstance(outer, PyLeaf):\n        return inner\n    else:\n        children = safe_map(partial(_nested_treedef, inner), outer.children)\n        return PyTreeDef(outer.node_type, outer.node_data, tuple(children))\n\n\nclass PyTreeDef(object):\n    def __init__(self, node_type, node_data, children):\n        self.node_type = node_type\n        self.node_data = node_data\n        self.children = children\n\n    def __repr__(self):\n        if self.node_data is None:\n            data_repr = \"\"\n        else:\n            data_repr = \"[{}]\".format(self.node_data)\n\n        return \"PyTree({}{}, [{}])\".format(self.node_type.name, data_repr, \",\".join(safe_map(repr, self.children)))\n\n    def __hash__(self):\n        return hash((self.node_type, self.node_data, tuple(self.children)))\n\n    def __eq__(self, other):\n        if isinstance(other, PyLeaf):\n            return False\n        else:\n            return self.node_type == other.node_type and self.node_data == other.node_data and self.children == other.children\n\n    def __ne__(self, other):\n        return not self == other\n\n\nclass PyLeaf(object):\n    def __repr__(self):\n        return \"*\"\n\n    def __eq__(self, other):\n        return isinstance(other, PyLeaf)\n\n\nclass NodeType(object):\n    def __init__(self, name, to_iterable, from_iterable):\n        self.name = name\n        self.to_iterable = 
to_iterable\n        self.from_iterable = from_iterable\n\n\nnode_types: Dict[type, NodeType] = {}\n\n\ndef register_pytree_node(py_type, to_iterable, from_iterable):\n    assert py_type not in node_types\n    node_types[py_type] = NodeType(str(py_type), to_iterable, from_iterable)\n\n\ndef tuple_to_iterable(xs):\n    return xs, None\n\n\ndef tuple_from_iterable(_keys, xs):\n    return tuple(xs)\n\n\ndef list_to_iterable(xs):\n    return tuple(xs), None\n\n\ndef list_from_iterable(_keys, xs):\n    return list(xs)\n\n\ndef dict_to_iterable(xs):\n    keys = tuple(sorted(xs.keys()))\n    return tuple(map(xs.get, keys)), keys\n\n\ndef dict_from_iterable(keys, xs):\n    return dict(safe_zip(keys, xs))\n\n\ndef ordered_dict_from_iterable(keys, xs):\n    return collections.OrderedDict(safe_zip(keys, xs))\n\n\ndef default_dict_to_iterable(xs):\n    return (tuple(xs.values()), (xs.default_factory, tuple(xs.keys())))\n\n\ndef default_dict_from_iterable(keys, xs):\n    return collections.defaultdict(keys[0], safe_zip(keys[1], xs))\n\n\ndef none_to_iterable(_xs):\n    return (), None\n\n\ndef none_from_iterable(_keys, _xs):\n    return None\n\n\nregister_pytree_node(tuple, tuple_to_iterable, tuple_from_iterable)\nregister_pytree_node(list, list_to_iterable, list_from_iterable)\nregister_pytree_node(dict, dict_to_iterable, dict_from_iterable)\nregister_pytree_node(collections.OrderedDict, dict_to_iterable, ordered_dict_from_iterable)\nregister_pytree_node(collections.defaultdict, default_dict_to_iterable, default_dict_from_iterable)\nregister_pytree_node(type(None), none_to_iterable, none_from_iterable)\n"
  },
  {
    "path": "metrics/IDM/lib/util.py",
    "content": "from typing import Dict, Optional\n\nimport torch as th\nfrom torch import nn\nfrom torch.nn import functional as F\n\nimport lib.torch_util as tu\nfrom lib.masked_attention import MaskedAttention\nfrom lib.minecraft_util import store_args\nfrom lib.tree_util import tree_map\n\n\ndef get_module_log_keys_recursive(m: nn.Module):\n    \"\"\"Recursively get all keys that a module and its children want to log.\"\"\"\n    keys = []\n    if hasattr(m, \"get_log_keys\"):\n        keys += m.get_log_keys()\n    for c in m.children():\n        keys += get_module_log_keys_recursive(c)\n    return keys\n\n\nclass FanInInitReLULayer(nn.Module):\n    \"\"\"Implements a slightly modified init that correctly produces std 1 outputs given ReLU activation\n    :param inchan: number of input channels\n    :param outchan: number of output channels\n    :param layer_args: positional layer args\n    :param layer_type: options are \"linear\" (dense layer), \"conv\" (2D Convolution), \"conv3d\" (3D convolution)\n    :param init_scale: multiplier on initial weights\n    :param batch_norm: use batch norm after the layer (for 2D data)\n    :param group_norm_groups: if not None, use group norm with this many groups after the layer. 
Group norm 1\n        would be equivalent to layernorm for 2D data.\n    :param layer_norm: use layernorm after the layer (for 1D data)\n    :param layer_kwargs: keyword arguments for the layer\n    \"\"\"\n\n    @store_args\n    def __init__(\n        self,\n        inchan: int,\n        outchan: int,\n        *layer_args,\n        layer_type: str = \"conv\",\n        init_scale: float = 1,\n        batch_norm: bool = False,\n        batch_norm_kwargs: Dict = {},\n        group_norm_groups: Optional[int] = None,\n        layer_norm: bool = False,\n        use_activation=True,\n        log_scope: Optional[str] = None,\n        **layer_kwargs,\n    ):\n        super().__init__()\n\n        # Normalization\n        self.norm = None\n        if batch_norm:\n            self.norm = nn.BatchNorm2d(inchan, **batch_norm_kwargs)\n        elif group_norm_groups is not None:\n            self.norm = nn.GroupNorm(group_norm_groups, inchan)\n        elif layer_norm:\n            self.norm = nn.LayerNorm(inchan)\n\n        layer = dict(conv=nn.Conv2d, conv3d=nn.Conv3d, linear=nn.Linear)[layer_type]\n        self.layer = layer(inchan, outchan, bias=self.norm is None, *layer_args, **layer_kwargs)\n\n        # Init Weights (Fan-In)\n        self.layer.weight.data *= init_scale / self.layer.weight.norm(\n            dim=tuple(range(1, self.layer.weight.data.ndim)), p=2, keepdim=True\n        )\n        # Init Bias\n        if self.layer.bias is not None:\n            self.layer.bias.data *= 0\n\n    def forward(self, x):\n        \"\"\"Norm after the activation. 
Experimented with this for both IAM and BC and it was slightly better.\"\"\"\n        if self.norm is not None:\n            x = self.norm(x)\n        x = self.layer(x)\n        if self.use_activation:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def get_log_keys(self):\n        return [\n            f\"activation_mean/{self.log_scope}\",\n            f\"activation_std/{self.log_scope}\",\n        ]\n\n\nclass ResidualRecurrentBlocks(nn.Module):\n    @store_args\n    def __init__(\n        self,\n        n_block=2,\n        recurrence_type=\"multi_layer_lstm\",\n        is_residual=True,\n        **block_kwargs,\n    ):\n        super().__init__()\n        init_scale = n_block ** -0.5 if is_residual else 1\n        self.blocks = nn.ModuleList(\n            [\n                ResidualRecurrentBlock(\n                    **block_kwargs,\n                    recurrence_type=recurrence_type,\n                    is_residual=is_residual,\n                    init_scale=init_scale,\n                    block_number=i,\n                )\n                for i in range(n_block)\n            ]\n        )\n\n    def forward(self, x, first, state):\n        state_out = []\n        assert len(state) == len(\n            self.blocks\n        ), f\"Length of state {len(state)} did not match length of blocks {len(self.blocks)}\"\n        for block, _s_in in zip(self.blocks, state):\n            x, _s_o = block(x, first, _s_in)\n            state_out.append(_s_o)\n        return x, state_out\n\n    def initial_state(self, batchsize):\n        if \"lstm\" in self.recurrence_type:\n            return [None for b in self.blocks]\n        else:\n            return [b.r.initial_state(batchsize) for b in self.blocks]\n\n\nclass ResidualRecurrentBlock(nn.Module):\n    @store_args\n    def __init__(\n        self,\n        hidsize,\n        timesteps,\n        init_scale=1,\n        recurrence_type=\"multi_layer_lstm\",\n        is_residual=True,\n        
use_pointwise_layer=True,\n        pointwise_ratio=4,\n        pointwise_use_activation=False,\n        attention_heads=8,\n        attention_memory_size=2048,\n        attention_mask_style=\"clipped_causal\",\n        log_scope=\"resblock\",\n        block_number=0,\n    ):\n        super().__init__()\n        self.log_scope = f\"{log_scope}{block_number}\"\n        s = init_scale\n        if use_pointwise_layer:\n            if is_residual:\n                s *= 2 ** -0.5  # second residual\n            self.mlp0 = FanInInitReLULayer(\n                hidsize,\n                hidsize * pointwise_ratio,\n                init_scale=1,\n                layer_type=\"linear\",\n                layer_norm=True,\n                log_scope=self.log_scope + \"/ptwise_mlp0\",\n            )\n            self.mlp1 = FanInInitReLULayer(\n                hidsize * pointwise_ratio,\n                hidsize,\n                init_scale=s,\n                layer_type=\"linear\",\n                use_activation=pointwise_use_activation,\n                log_scope=self.log_scope + \"/ptwise_mlp1\",\n            )\n\n        self.pre_r_ln = nn.LayerNorm(hidsize)\n        if recurrence_type in [\"multi_layer_lstm\", \"multi_layer_bilstm\"]:\n            self.r = nn.LSTM(hidsize, hidsize, batch_first=True)\n            nn.init.normal_(self.r.weight_hh_l0, std=s * (self.r.weight_hh_l0.shape[0] ** -0.5))\n            nn.init.normal_(self.r.weight_ih_l0, std=s * (self.r.weight_ih_l0.shape[0] ** -0.5))\n            self.r.bias_hh_l0.data *= 0\n            self.r.bias_ih_l0.data *= 0\n        elif recurrence_type == \"transformer\":\n            self.r = MaskedAttention(\n                input_size=hidsize,\n                timesteps=timesteps,\n                memory_size=attention_memory_size,\n                heads=attention_heads,\n                init_scale=s,\n                norm=\"none\",\n                log_scope=log_scope + \"/sa\",\n                use_muP_factor=True,\n      
          mask=attention_mask_style,\n            )\n\n    def forward(self, x, first, state):\n        residual = x\n        x = self.pre_r_ln(x)\n        x, state_out = recurrent_forward(\n            self.r,\n            x,\n            first,\n            state,\n            reverse_lstm=self.recurrence_type == \"multi_layer_bilstm\" and (self.block_number + 1) % 2 == 0,\n        )\n        if self.is_residual and \"lstm\" in self.recurrence_type:  # Transformer already residual.\n            x = x + residual\n        if self.use_pointwise_layer:\n            # Residual MLP\n            residual = x\n            x = self.mlp1(self.mlp0(x))\n            if self.is_residual:\n                x = x + residual\n        return x, state_out\n\n\ndef recurrent_forward(module, x, first, state, reverse_lstm=False):\n    if isinstance(module, nn.LSTM):\n        if state is not None:\n            # In case recurrent models do not accept a \"first\" argument we zero out the hidden state here\n            mask = 1 - first[:, 0, None, None].to(th.float)\n            state = tree_map(lambda _s: _s * mask, state)\n            state = tree_map(lambda _s: _s.transpose(0, 1), state)  # NL, B, H\n        if reverse_lstm:\n            x = th.flip(x, [1])\n        x, state_out = module(x, state)\n        if reverse_lstm:\n            x = th.flip(x, [1])\n        state_out = tree_map(lambda _s: _s.transpose(0, 1), state_out)  # B, NL, H\n        return x, state_out\n    else:\n        return module(x, first, state)\n\n\ndef _banded_repeat(x, t):\n    \"\"\"\n    Repeats x with a shift.\n    For example (ignoring the batch dimension):\n\n    _banded_repeat([A B C D E], 4)\n    =\n    [D E 0 0 0]\n    [C D E 0 0]\n    [B C D E 0]\n    [A B C D E]\n    \"\"\"\n    b, T = x.shape\n    x = th.cat([x, x.new_zeros(b, t - 1)], dim=1)\n    result = x.unfold(1, T, 1).flip(1)\n    return result\n\n\ndef bandify(b_nd, t, T):\n    \"\"\"\n    b_nd -> D_ntT, where\n        \"n\" indexes over basis 
functions\n        \"d\" indexes over time differences\n        \"t\" indexes over output time\n        \"T\" indexes over input time\n        only t >= T is nonzero\n    B_ntT[n, t, T] = b_nd[n, t - T]\n    \"\"\"\n    nbasis, bandsize = b_nd.shape\n    b_nd = b_nd[:, th.arange(bandsize - 1, -1, -1)]\n    if bandsize >= T:\n        b_nT = b_nd[:, -T:]\n    else:\n        b_nT = th.cat([b_nd.new_zeros(nbasis, T - bandsize), b_nd], dim=1)\n    D_tnT = _banded_repeat(b_nT, t)\n    return D_tnT\n\n\ndef get_norm(name, d, dtype=th.float32):\n    if name == \"none\":\n        return lambda x: x\n    elif name == \"layer\":\n        return tu.LayerNorm(d, dtype=dtype)\n    else:\n        raise NotImplementedError(name)\n"
  },
  {
    "path": "metrics/IDM/lib/xf.py",
    "content": "\"\"\"\nImplementation of transformer and reshaping-based sparse transformer\n\"\"\"\nimport functools\nimport math\n\nimport torch as th\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom lib import misc, mlp\nfrom lib import torch_util as tu\nfrom lib import util\n\nSENTINEL = 0.1337\n\n\ndef attention(\n    Q_bte,\n    K_bTe,\n    V_bTe,\n    dtype,\n    mask=True,\n    extra_btT=None,\n    maxlen=None,\n    check_sentinel=False,\n    use_muP_factor=False,\n):\n    \"\"\"\n    performs softmax(Q*K)*V operation\n\n    t : output (write) time axis, possibly size=1 for just the last timestep\n    T : input (read) time axis\n    t < T is OK\n\n    'check_sentinel' is used when you want to make it impossible to attend to certain keys.\n    All keys where every value is equal to the constant SENTINEL will be ignored.\n    Currently this is only used by StridedAttn.\n    \"\"\"\n    assert Q_bte.dtype == K_bTe.dtype == dtype, f\"{Q_bte.dtype}, {K_bTe.dtype}, {dtype} must all match\"\n    e = Q_bte.shape[2]\n    if check_sentinel:\n        invalid = (K_bTe == SENTINEL).int().sum(dim=-1) == e\n        invalid = misc.reshape(invalid, \"b, T\", \"b, 1, T\")\n    if isinstance(mask, th.Tensor):\n        bias = (~mask).float() * -1e9\n    elif mask:\n        bias = get_attn_bias_cached(Q_bte.shape[1], K_bTe.shape[1], maxlen=maxlen, device=Q_bte.device, dtype=th.float32)\n    else:\n        bias = Q_bte.new_zeros((), dtype=th.float32)\n    if extra_btT is not None:\n        bias = bias + extra_btT\n    # Equivalent to bias + (1 / math.sqrt(e)) * th.einsum(\"bte,bpe->btp\", Q_bte, K_bte)\n    # but faster:\n    logit_btT = th.baddbmm(\n        bias,\n        Q_bte.float(),\n        K_bTe.float().transpose(-1, -2),\n        alpha=(1 / e) if use_muP_factor else (1 / math.sqrt(e)),\n    )\n    if check_sentinel:\n        logit_btT = logit_btT - 1e9 * invalid.float()\n    W_btT = th.softmax(logit_btT, dim=2).to(dtype)\n    if callable(V_bTe):\n     
   # This is used by the sharded video model to defer waiting on\n        # the broadcast of the values until they're needed\n        V_bTe = V_bTe()\n    # th.einsum only lets you use lowercase letters, so 'p' for 'past'\n    # means 'T'\n    A_bte = th.einsum(\"btp,bpe->bte\", W_btT, V_bTe)\n    return A_bte\n\n\nclass Attn:\n    \"\"\"\n    Defines an attention mechanism\n    All the mechanisms here can be defined by two operations:\n    1. preprocessing Q,K,V,R[=relative attention query]\n        to move axes from embedding dimension to\n        batch dimension, and possibly doing shifts.\n    2. postprocessing the final result to move axes back to embedding\n        axis.\n    \"\"\"\n\n    def __init__(self, mask, maxlen):\n        self.mask = mask\n        self.maxlen = maxlen\n\n    def preproc_qkv(self, Q_bte, K_bte, V_bte):\n        raise NotImplementedError\n\n    def preproc_r(self, R_btn):\n        raise NotImplementedError\n\n\ndef split_heads(x_bte, h):\n    b, t, e = x_bte.shape\n    assert e % h == 0, \"Embsize must be divisible by number of heads\"\n    q = e // h\n    x_bthq = x_bte.reshape((b, t, h, q))\n    x_bhtq = misc.transpose(x_bthq, \"bthq\", \"bhtq\")\n    x_Btq = x_bhtq.reshape((b * h, t, q))\n    return x_Btq\n\n\nclass All2All(Attn):\n    def __init__(self, nhead, maxlen, mask=True, head_dim=None):\n        super().__init__(mask=mask, maxlen=maxlen)\n        assert (nhead is None) != (head_dim is None), \"exactly one of nhead and head_dim must be specified\"\n        self.h = nhead\n        self.head_dim = head_dim\n\n    def preproc_qkv(self, *xs):\n        q = xs[0].shape[-1]\n        for x in xs:\n            assert x.shape[-1] == q, \"embedding dimensions do not match\"\n        h = self.h or misc.exact_div(q, self.head_dim)\n        postproc = functools.partial(self.postproc_a, h=h)\n        return (postproc, *tuple(split_heads(x, h) for x in xs))\n\n    def preproc_r(self, R_btn):\n        _, ret = self.preproc_qkv(R_btn)\n      
  return ret\n\n    def postproc_a(self, A_Btq, h):\n        B, t, q = A_Btq.shape\n        b = B // h\n        A_bhtq = A_Btq.reshape((b, h, t, q))\n        A_bthq = misc.transpose(A_bhtq, \"bhtq\", \"bthq\")\n        A_bte = A_bthq.reshape((b, t, h * q))\n        return A_bte\n\n\ndef _required_padding(dim, target_div):\n    if dim % target_div == 0:\n        return 0\n    else:\n        return target_div - dim % target_div\n\n\nclass StridedAttn(Attn):\n    def __init__(self, nhead, stride, maxlen, mask=True):\n        super().__init__(mask=mask, maxlen=maxlen)\n        self.h = nhead\n        self.stride = stride\n\n    def _preproc(self, x, name, Q_t=None, Q_pad=None):\n        x, undo = misc.reshape_undo(x, \"b, t*stride, e\", \"b, 1, t, stride*e\", stride=self.stride)\n        if name == \"Q\":\n            Q_pad = _required_padding(x.shape[2], self.maxlen)\n        original_t = x.shape[2]\n        x = F.pad(x, (0, 0, 0, Q_pad), value=SENTINEL)\n        undo = misc.compose_undo(undo, lambda x: x[:, :, :original_t])\n        if name == \"Q\":\n            Q_t = x.shape[2]\n            assert Q_t % self.maxlen == 0, f\"{Q_t} % {self.maxlen} != 0\"\n        else:\n            required_len = Q_t + self.maxlen\n            if x.shape[2] < required_len:\n                x = F.pad(x, (0, 0, required_len - x.shape[2], 0), value=SENTINEL)\n            assert x.shape[2] >= required_len\n            back = x[:, :, -Q_t - self.maxlen : -self.maxlen]\n            front = x[:, :, -Q_t:]\n            x = th.cat([back, front], dim=1)\n        _, _, t, _ = x.shape\n        assert t == Q_t, f\"{t} != {Q_t}\"\n        x, undo = misc.reshape_undo(\n            x,\n            \"b, pad_shift, t*maxlen, stride*h*q\",\n            \"b, pad_shift, t, maxlen, stride, h, q\",\n            maxlen=self.maxlen,\n            h=self.h,\n            stride=self.stride,\n            undo=undo,\n        )\n        x, undo = misc.transpose_undo(x, \"bptmshq\", \"bthspmq\", undo=undo)\n        
x, undo = misc.reshape_undo(\n            x,\n            \"b, t, h, stride, pad_shift, maxlen, q\",\n            \"b*t*h*stride, pad_shift*maxlen, q\",\n            undo=undo,\n        )\n        if name == \"Q\":\n            return x, undo, Q_t, Q_pad\n        else:\n            return x\n\n    def preproc_qkv(self, Q_bte, K_bte, V_bte):\n        pad = _required_padding(Q_bte.shape[1], self.stride)\n        if pad:\n            Q_bte = F.pad(Q_bte, (0, 0, 0, pad), value=SENTINEL)\n            K_bte = F.pad(K_bte, (0, 0, 0, pad), value=SENTINEL) if K_bte is not None else None\n            V_bte = F.pad(V_bte, (0, 0, 0, pad), value=SENTINEL) if V_bte is not None else None\n            undo = lambda x, pad=pad: x[:, :-pad]\n        else:\n            undo = None\n        if K_bte is not None:\n            pad = _required_padding(K_bte.shape[1], self.stride)\n            if pad:\n                K_bte = F.pad(K_bte, (0, 0, pad, 0), value=SENTINEL)\n                V_bte = F.pad(V_bte, (0, 0, pad, 0), value=SENTINEL)\n        assert Q_bte.shape[1] % self.stride == 0\n        assert K_bte is None or K_bte.shape[1] % self.stride == 0\n        assert V_bte is None or V_bte.shape[1] % self.stride == 0\n        Q, postproc, Q_t, Q_pad = self._preproc(Q_bte, \"Q\")\n        postproc = misc.compose_undo(undo, postproc)\n        return (\n            postproc,\n            Q,\n            self._preproc(K_bte, \"K\", Q_t=Q_t, Q_pad=Q_pad) if K_bte is not None else None,\n            self._preproc(V_bte, \"V\", Q_t=Q_t, Q_pad=Q_pad) if V_bte is not None else None,\n        )\n\n    def preproc_r(self, R_bte):\n        _, R, _, _ = self.preproc_qkv(R_bte, None, None)\n        return R\n\n\nQ_SCALE = 0.1\nK_SCALE = 0.2\nV_SCALE = 1.0\nPROJ_SCALE = 1.0\nMLP0_SCALE = 1.0\nMLP1_SCALE = 1.0\nR_SCALE = 0.1\nB_SCALE = 0.2\n\n\nclass AttentionLayerBase(nn.Module):\n    def __init__(\n        self,\n        *,\n        attn,\n        scale,\n        x_size,\n        c_size,\n        
qk_size,\n        v_size,\n        dtype,\n        relattn=False,\n        seqlens=None,\n        separate=False,\n    ):\n        super().__init__()\n        dtype = tu.parse_dtype(dtype)\n        self.attn = attn\n        self.x_size = x_size\n        self.c_size = c_size\n        s = math.sqrt(scale)\n        separgs = dict(seqlens=seqlens, separate=separate)\n        self.q_layer = MultiscaleLinear(x_size, qk_size, name=\"q\", scale=Q_SCALE, dtype=dtype, **separgs)\n        self.k_layer = MultiscaleLinear(c_size, qk_size, name=\"k\", scale=K_SCALE, bias=False, dtype=dtype, **separgs)\n        self.v_layer = MultiscaleLinear(c_size, v_size, name=\"v\", scale=V_SCALE * s, bias=False, dtype=dtype, **separgs)\n        self.proj_layer = MultiscaleLinear(v_size, x_size, name=\"proj\", scale=PROJ_SCALE * s, dtype=dtype, **separgs)\n        self.relattn = relattn\n        maxlen = attn.maxlen\n        assert maxlen > 0 or not attn.mask\n        if self.relattn:\n            nbasis = 10\n            self.r_layer = tu.NormedLinear(x_size, nbasis * attn.h, scale=R_SCALE, dtype=dtype)\n            self.b_nd = nn.Parameter(th.randn(nbasis, maxlen) * B_SCALE)\n        self.maxlen = maxlen\n        self.dtype = dtype\n\n    def relattn_logits(self, X_bte, T):\n        R_btn = self.r_layer(X_bte).float()\n        R_btn = self.attn.preproc_r(R_btn)\n        t = R_btn.shape[1]\n        D_ntT = util.bandify(self.b_nd, t, T)\n        extra_btT = th.einsum(\"btn,ntp->btp\", R_btn, D_ntT)\n        return extra_btT\n\n\ndef quick_gelu(x):\n    return x * th.sigmoid(1.702 * x)\n\n\ndef act(actname, x):\n    if actname == \"relu\":\n        return F.relu(x)\n    elif actname == \"gelu\":\n        return quick_gelu(x)\n    elif actname == \"none\":\n        return x\n    else:\n        raise NotImplementedError(actname)\n\n\nclass SelfAttentionLayer(AttentionLayerBase):\n    \"\"\"\n    Residual attention layer that takes a single tensor x and has it attend to itself\n    Has the form\n 
       output = x + f(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        x_size,\n        attn,\n        scale,\n        dtype=\"float32\",\n        norm=\"layer\",\n        cache_keep_len=None,\n        relattn=False,\n        log_scope=\"sa\",\n        use_muP_factor=False,\n        **kwargs,\n    ):\n        super().__init__(\n            x_size=x_size,\n            c_size=x_size,\n            qk_size=x_size,\n            v_size=x_size,\n            attn=attn,\n            scale=scale,\n            relattn=relattn,\n            dtype=dtype,\n            **kwargs,\n        )\n        self.ln_x = util.get_norm(norm, x_size, dtype=dtype)\n        if cache_keep_len is None:\n            if hasattr(attn, \"cache_keep_len\"):\n                cache_keep_len = attn.cache_keep_len\n            else:\n                if isinstance(attn, StridedAttn):\n                    stride = attn.stride\n                else:\n                    stride = 1\n                cache_keep_len = stride * attn.maxlen\n        self.cache_keep_len = cache_keep_len\n        self.log_scope = log_scope\n        self.use_muP_factor = use_muP_factor\n\n    def residual(self, X_bte, state):\n        X_bte = self.ln_x(X_bte)\n        Q_bte = self.q_layer(X_bte)\n        K_bte = self.k_layer(X_bte)\n        V_bte = self.v_layer(X_bte)\n        if state:\n            state, K_bte, V_bte = self.update_state(state, K_bte, V_bte)\n        postproc_closure, Q_bte, K_bte, V_bte = self.attn.preproc_qkv(Q_bte, K_bte, V_bte)\n        extra_btT = self.relattn_logits(X_bte, K_bte.shape[1]) if self.relattn else None\n        A_bte = attention(\n            Q_bte,\n            K_bte,\n            V_bte,\n            mask=self.attn.mask,\n            extra_btT=extra_btT,\n            maxlen=self.maxlen,\n            dtype=self.dtype,\n            check_sentinel=isinstance(self.attn, StridedAttn),\n            use_muP_factor=self.use_muP_factor,\n        )\n        A_bte = postproc_closure(A_bte)\n       
 Aproj_bte = self.proj_layer(A_bte)\n        return Aproj_bte, state\n\n    def forward(self, X_bte, state):\n        R_bte, state = self.residual(X_bte, state)\n        return X_bte + R_bte, state\n\n    def stateless_forward(self, X_bte):\n        out_bte, _state = self.forward(X_bte, None)\n        return out_bte\n\n    def update_state(self, state, K_bte, V_bte):\n        def append(prev, new):\n            \"\"\"\n            Given `prev` keys from cache, and `new` keys,\n            returns (cache, full), where\n            - cache goes into the output state, length chosen so that on the\n                next timestep, there are enough cached timesteps to get the full\n                context of length self.maxlen.\n            - full is used for the current forward pass, with length chosen so\n                that the first timestep new[:, 0] gets to see a context of\n                self.maxlen.\n            \"\"\"\n            tprev = prev.shape[1]\n            startfull = max(tprev - self.cache_keep_len, 0)\n            full = th.cat([prev[:, startfull:], new], dim=1)\n            outstate = full[:, max(full.shape[1] - (self.cache_keep_len), 0) :]\n            # To see that the preceding slicing is correct, consider the case\n            # that maxlen==1. 
Then `full` only consists of `new`, and\n            # `outstate` is empty\n            return outstate, full\n\n        instate_K, instate_V = state\n        outstate_K, K_bte = append(instate_K, K_bte)\n        outstate_V, V_bte = append(instate_V, V_bte)\n        assert outstate_K.shape[-2] <= self.cache_keep_len\n        return (outstate_K, outstate_V), K_bte, V_bte\n\n    def initial_state(self, batchsize, initial_T=0):\n        return (\n            tu.zeros((batchsize, initial_T, self.x_size), dtype=self.dtype),\n            tu.zeros((batchsize, initial_T, self.x_size), dtype=self.dtype),\n        )\n\n    def empty_state(self):\n        return None\n\n\nclass PointwiseLayer(nn.Module):\n    \"\"\"\n    Residual MLP applied at each timestep\n    \"\"\"\n\n    def __init__(self, x_size, scale, dtype, norm, actname=\"relu\", mlp_ratio=2):\n        super().__init__()\n        s = math.sqrt(scale)\n        self.ln = util.get_norm(norm, x_size, dtype=dtype)\n        self.mlp = mlp.MLP(\n            insize=x_size,\n            nhidlayer=1,\n            outsize=x_size,\n            hidsize=int(x_size * mlp_ratio),\n            hidactiv=functools.partial(act, actname),\n            dtype=dtype,\n        )\n        self.mlp.layers[0].weight.data *= MLP0_SCALE * s\n        self.mlp.layers[1].weight.data *= MLP1_SCALE * s\n\n    def residual(self, x):\n        x = self.ln(x)\n        x = self.mlp(x)\n        return x\n\n    def forward(self, x):\n        return x + self.residual(x)\n\n\ndef _is_separate(sep, name):\n    if isinstance(sep, bool):\n        return sep\n    assert isinstance(sep, set)\n    if name in sep:\n        sep.remove(name)\n        return True\n    else:\n        return False\n\n\ndef make_maybe_multiscale(make_fn, *args, seqlens, separate, name, **kwargs):\n    \"\"\"\n    This function either creates one instance of a module or creates\n    a separate instance of the module for each resolution of the image,\n    determined by the `separate` 
parameter. We create separate modules\n    if `separate` is True or if `separate` is a set containing `name`.\n    \"\"\"\n    if _is_separate(separate, name):\n        modules = [make_fn(*args, **kwargs) for _ in seqlens]\n        return SplitCallJoin(modules, seqlens)\n    else:\n        return make_fn(*args, **kwargs)\n\n\nclass SplitCallJoin(nn.Module):\n    def __init__(self, mods, seqlens):\n        super().__init__()\n        self.mods = nn.ModuleList(mods)\n        self.seqlens = seqlens\n\n    def forward(self, x):\n        tl = sum(self.seqlens)\n        x, undo = misc.reshape_undo(x, \"..., z*tl, e\", \"..., z, tl, e\", tl=tl)\n        x = list(th.split(x, self.seqlens, dim=-2))\n        new_x = []\n        for x, mod in misc.safezip(x, self.mods):\n            x, this_undo = misc.reshape_undo(x, \"..., z, l, e\", \"..., z*l, e\")\n            x = mod(x)\n            x = this_undo(x)\n            new_x.append(x)\n        x = th.cat(new_x, dim=-2)\n        x = undo(x)\n        return x\n\n\nMultiscaleLinear = functools.partial(make_maybe_multiscale, tu.NormedLinear)\nMultiscalePointwise = functools.partial(make_maybe_multiscale, PointwiseLayer)\n"
  },
  {
    "path": "metrics/common_metrics.py",
    "content": "import sys\nimport os\nsys.path.append(os.getcwd())\nfrom common_metrics_on_video_quality.calculate_fvd import calculate_fvd\nfrom common_metrics_on_video_quality.calculate_lpips import calculate_lpips\nfrom common_metrics_on_video_quality.calculate_ssim import calculate_ssim\nfrom common_metrics_on_video_quality.calculate_psnr import calculate_psnr\nimport os\nimport cv2\nimport torch\nimport numpy as np\nimport argparse\nimport json\ndevice = torch.device(\"cuda\")\n\ndef load_videos_to_tensor(video_dir, number_of_videos, video_length, channel, size,video_files=None):\n    videos_tensor = torch.zeros(number_of_videos, video_length, channel, size[0], size[1], requires_grad=False)\n    if video_files is None:\n        video_files = [f for f in os.listdir(video_dir) if f.endswith(('.mp4'))]\n    video_files = sorted(video_files, key=lambda x: int(x.split(\"_\")[-1].split(\".\")[0]))\n    video_files = video_files[:number_of_videos]\n    for i, video_file in enumerate(video_files):\n        video_path = os.path.join(video_dir, video_file)\n\n        cap = cv2.VideoCapture(video_path)\n        if not cap.isOpened():\n            print(f\"Failed to open video: {video_path}\")\n            continue\n\n        frames = []\n        # get video total length ; our gt has 16 frame but we only use 15 frame so set video_length to 15 and start frame to 1 \n        real_video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n        if real_video_length > video_length:\n            # set start frame to video_length - video_length\n            start_frame = real_video_length - video_length\n            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)\n        else:\n            # set start frame to 0\n            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)\n        while len(frames) < video_length:\n            ret, frame = cap.read()\n            if not ret: \n                break\n            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            frame = 
cv2.resize(frame, (size[1], size[0]))  # Resize to (height, width)\n            frames.append(frame)\n        if len(frames) < video_length:\n            print(f\"Video {video_file} has fewer frames than expected. Expected: {video_length}, Found: {len(frames)}. Exiting...\")\n            exit(1)\n        cap.release()\n        frames_np = np.array(frames[:video_length])\n        frames_np = np.transpose(frames_np, (0, 3, 1, 2))\n        videos_tensor[i] = torch.tensor(frames_np, dtype=torch.float32) / 255.0\n\n    return videos_tensor\n\n# python scripts/tvideo/mc/common_metrics.py --video_dir1 metrics_table_1/oasis/oasis_official_results_no_demo_1gen15_mineworld_curation --video_dir2 metrics_table_1/frame_16_curation --video_length 15 --channel 3 --size \"(224,384)\" --output-file test_metrics.json\ndef main():\n    parser = argparse.ArgumentParser(description=\"Calculate FVD, SSIM, PSNR, and LPIPS for two sets of videos.\")\n    parser.add_argument(\"--video_dir1\", type=str, required=True, help=\"Path to the first directory containing videos.\")\n    parser.add_argument(\"--video_dir2\", type=str, required=True, help=\"Path to the second directory containing videos.\")\n    parser.add_argument(\"--video_length\", type=int, default=32, help=\"Number of frames to retain from each video.\")\n    parser.add_argument(\"--channel\", type=int, default=3, help=\"Number of channels in the videos (default: 3 for RGB).\")\n    parser.add_argument(\"--size\", type=str, default=\"(224,384)\", help=\"Frame size as (height, width) (default: (224,384)).\")\n    parser.add_argument(\"--output-file\", type=str, help=\"Path to write the metrics JSON.\")\n    args = parser.parse_args()\n    args.size = eval(args.size)  # e.g. \"(224,384)\" -> (224, 384)\n    print(\"args.size\", args.size)\n    video_files = [f for f in os.listdir(args.video_dir1) if f.endswith('.mp4')]\n    number_of_videos = min(500, len(video_files))\n    print(\"number_of_videos\", number_of_videos)\n    videos1 = 
load_videos_to_tensor(args.video_dir1, number_of_videos, args.video_length, args.channel, args.size, video_files)\n    videos2 = load_videos_to_tensor(args.video_dir2, number_of_videos, args.video_length, args.channel, args.size, video_files)\n    print(\"videos1.shape\", videos1.shape, \"videos2.shape\", videos2.shape)\n    device = torch.device(\"cuda\")\n\n    print(args.output_file)\n    result = {}\n    result['fvd'] = calculate_fvd(videos1, videos2, device, method='styleganv')\n    # result['fvd'] = calculate_fvd(videos1, videos2, device, method='videogpt')\n    result['ssim'] = calculate_ssim(videos1, videos2)\n    result['psnr'] = calculate_psnr(videos1, videos2)\n    result['lpips'] = calculate_lpips(videos1, videos2, device)\n    lpips_mean =  np.mean(list(result['lpips']['value']))\n    ssim_mean =  np.mean(list(result['ssim']['value']))\n    psnr_mean =  np.mean(list(result['psnr']['value']))\n    fvd_mean =  np.mean(list(result['fvd']['value']))\n    data_item = {\"exp_name\":args.video_dir1, \"fvd\":fvd_mean, \"lpips\":lpips_mean, \"ssim\":ssim_mean, \"psnr\":psnr_mean}\n    print(data_item)\n    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)\n    result['mean'] = data_item\n    with open(args.output_file, \"w\") as f:\n        json.dump(result, f, indent=4)\n    print(\"results saved to \", args.output_file)\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "metrics/tabulate_all_results.py",
    "content": "import argparse\nimport os \nimport json\nimport sys\nimport pandas as pd\nimport numpy as np\nfrom rich import print\n\ndef tabluate_metrics(input_dir,output_path):\n    metrics_list = []\n    all_files = [f for f in os.listdir(input_dir) if f.endswith('.json')]\n    idm_results = [f for f in all_files if 'idm' in f]\n    fvd_results = [f for f in all_files if 'fvd' in f]\n    exps = set([i.replace(\"idm_\",\"\").replace(\".json\",\"\") for i in idm_results]) & set([i.replace(\"fvd_\",\"\").replace(\".json\",\"\") for i in fvd_results])\n    exps = list(exps)\n    print(f\"[bold magenta][Tabulating Evaluation Results][/bold magenta]: Found experiments : {exps}\")\n    for exp in exps:\n        idm_file = os.path.join(input_dir, f\"idm_{exp}.json\")\n        fvd_file = os.path.join(input_dir, f\"fvd_{exp}.json\")\n        with open(idm_file, 'r') as f:\n            idm_data = json.load(f)\n        with open(fvd_file, 'r') as f:\n            fvd_data = json.load(f)\n        fvd_data = fvd_data[\"mean\"]\n        fvd_data.pop(\"exp_name\", None)\n        idm_data = idm_data[\"metric_mean_on_task\"]\n        metrics_entry = {\n            \"experiment\": exp,\n        }\n        # merge dict \n        metrics_entry.update(fvd_data)\n        metrics_entry.update(idm_data)\n        metrics_list.append(metrics_entry)\n    # Convert list of metrics to a DataFrame\n    df = pd.DataFrame(metrics_list)\n\n    # Save the DataFrame to a CSV file\n    df.to_csv(output_path, index=False)\n    print(f\"[bold red][Tabulating Evaluation Results End][/bold red] Metrics tabulated and saved to {output_path}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Tabulate metrics from JSON files\")\n    parser.add_argument(\"--input_dir\", type=str, required=True, help=\"Directory containing JSON metric files\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"Path to save the tabulated metrics CSV file\")\n    
args = parser.parse_args()\n    \n    tabluate_metrics(args.input_dir, args.output_path)"
  },
  {
    "path": "mineworld.py",
    "content": "import os\nimport sys\nsys.path.append(os.getcwd())\nimport gradio as gr\nfrom PIL import Image\nimport numpy as np\nimport torch\nimport cv2\nfrom utils import load_model\nfrom omegaconf import OmegaConf\nfrom argparse import ArgumentParser\nfrom collections import deque\nimport tempfile\nimport atexit\nfrom torchvision import transforms\nfrom einops import rearrange\nfrom mcdataset import MCDataset\nimport itertools\n\nclass Buttons:\n    ATTACK = \"attack\"\n    BACK = \"back\"\n    FORWARD = \"forward\"\n    JUMP = \"jump\"\n    LEFT = \"left\"\n    RIGHT = \"right\"\n    SNEAK = \"sneak\"\n    SPRINT = \"sprint\"\n    USE = \"use\"\n    DROP = \"drop\"\n    SWAPHANDS = \"swapHands\"\n    PICKITEM = \"pickItem\"\n\n    ALL = [\n        ATTACK,\n        USE,\n\n        FORWARD,\n        BACK,\n        LEFT,\n        RIGHT,\n\n        JUMP,\n        SNEAK,\n        SPRINT,\n\n        DROP,\n        SWAPHANDS,\n        PICKITEM,\n\n        # INVENTORY,\n        # ESC,\n    ] + [f\"hotbar.{i}\" for i in range(1, 10)]\n\nKEYBOARD_BUTTON_MAPPING = {\n    \"key.keyboard.s\" :\"back\",\n    \"key.keyboard.q\" :\"drop\",\n    \"key.keyboard.w\" :\"forward\",\n    \"key.keyboard.1\" :\"hotbar.1\",\n    \"key.keyboard.2\" :\"hotbar.2\",\n    \"key.keyboard.3\" :\"hotbar.3\",\n    \"key.keyboard.4\" :\"hotbar.4\",\n    \"key.keyboard.5\" :\"hotbar.5\",\n    \"key.keyboard.6\" :\"hotbar.6\",\n    \"key.keyboard.7\" :\"hotbar.7\",\n    \"key.keyboard.8\" :\"hotbar.8\",\n    \"key.keyboard.9\" :\"hotbar.9\",\n    \"key.keyboard.space\" :\"jump\",\n    \"key.keyboard.a\" :\"left\",\n    \"key.keyboard.d\" :\"right\",\n    \"key.keyboard.left.shift\" :\"sneak\",\n    \"key.keyboard.left.control\" :\"sprint\",\n    \"key.keyboard.f\" :\"swapHands\",\n}\n# Template action\nNOOP_ACTION = {\n    \"forward\": 0,\n    \"back\": 0,\n    \"left\": 0,\n    \"right\": 0,\n\n    \"jump\": 0,\n    \"attack\": 0,\n    \"use\": 0,\n    \"pickItem\": 0,\n\n    \"drop\": 0,\n   
 \"sneak\": 0,\n    \"sprint\": 0,\n    \"swapHands\": 0,\n\n    \"hotbar.1\": 0,\n    \"hotbar.2\": 0,\n    \"hotbar.3\": 0,\n    \"hotbar.4\": 0,\n    \"hotbar.5\": 0,\n    \"hotbar.6\": 0,\n    \"hotbar.7\": 0,\n    \"hotbar.8\": 0,\n    \"hotbar.9\": 0,\n    \"camera\": np.array([0, 0]),  \n}\n\nACTION_BUTTON = {\n    \"forward\": 0,\n    \"back\": 0,\n    \"left\": 0,\n    \"right\": 0,\n\n    \"attack\": 0,\n    \"sprint\": 0,\n    \"jump\": 0,\n    \"use\": 0,\n\n    \"drop\": 0,\n    \"hotbar.1\": 0,\n    \"pickItem\": 0,\n}\nFOR_BACK = {\n    \"forward\": 0,\n    \"back\": 0,\n}\nL_R = {\n    \"left\": 0,\n    \"right\": 0,   \n}\nATT_USE_DROP = {\n    \"attack\": 0,\n    \"use\": 0,\n    \"drop\": 0,\n}\nJUMP_SPR = {\n    \"jump\": 0,\n    \"sprint\": 0,\n}\nHORBAR = {\n    \"hotbar.1\": 0,\n    \"hotbar.2\": 0,\n    \"hotbar.3\": 0,\n    \"hotbar.4\": 0,\n    \"hotbar.5\": 0,\n    \"hotbar.6\": 0,\n    \"hotbar.7\": 0,\n    \"hotbar.8\": 0,\n    \"hotbar.9\": 0,   \n}\n\n\nsafe_globals = {\"array\": np.array}\n\nAGENT_RESOLUTION = (384, 224)\nCAMERA_SCALER = 360.0 / 2400.0\nTOKEN_PER_IMAGE = 336\nTOKEN_PER_ACTION = 11\nVIDEO_FRAMES = []\nGENERATED_FILES = []\nframe_cache = []\naction_cache = []\nlast_pos = 0\nMC_ACTION_MAP = MCDataset()\nSHOW_FRAMES = 8\nREFERENCE_FRAME = None\nCONTEXT_LEN = None\nDIAGD = False\nWINDOWSIZE = 4\n\ndef get_args():\n    parser = ArgumentParser()\n    parser.add_argument('--scene', type=str, default='./assets/scene.mp4')\n    parser.add_argument('--model_ckpt', type=str, default='./checkpoints/700M_16f.pt')\n    parser.add_argument('--config', type=str, default='./configs/700M_16f.yaml')\n    parser.add_argument('--reference_frame', type=int,  default=8)\n    parser.add_argument('--diagd', action='store_true', help='use diagd')\n    parser.add_argument('--window_size', type=int,  default=4)\n    args = parser.parse_args()\n    return args\n\ndef make_action_dict(action_line):\n    action_dict = {'ESC': 0, 'back': 0, 'drop': 
0, 'forward': 0, 'hotbar.1': 0, 'hotbar.2': 0, 'hotbar.3': 0, 'hotbar.4': 0, 'hotbar.5': 0, 'hotbar.6': 0, 'hotbar.7': 0, 'hotbar.8': 0, 'hotbar.9': 0, 'inventory': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0, 'swapHands': 0, 'camera': np.array([0, 0]), 'attack': 0, 'use': 0, 'pickItem': 0}\n    if isinstance(action_line, str):\n        action_line = action_line.split(\",\")\n        action_dict['camera'] = np.array((int(action_line[-2]), int(action_line[-1])))\n    for act in action_line:\n        if act in Buttons.ALL:\n            action_dict[act] = 1\n    \n    return action_dict\n\ndef stack_images(imgs):\n    width, height = imgs[0].size\n    new_im = Image.new('RGB', (4*width, height*2))\n    for i, im in enumerate(imgs):\n        new_im.paste(im, (width*(i%4), height*(i//4)))\n    return new_im\n\ndef get_action_line(acts):\n    action_lst = []\n    for k in acts.keys():\n        if k != \"camera\" and acts[k] == 1:\n            action_lst.append(k)\n    action_lst.append(str(acts[\"camera\"][0]))\n    action_lst.append(str(acts[\"camera\"][1]))    \n\n    return \",\".join(action_lst)\n\ndef run_prediction(btns_choices, cam_x_input, cam_y_input):\n    global frame_cache, action_cache, actions_show, images_show, VIDEO_FRAMES, last_pos, CONTEXT_LEN, REFERENCE_FRAME    \n    assert len(frame_cache) == len(action_cache)+1\n    if len(action_cache) >= CONTEXT_LEN - 1:\n        for _ in range(CONTEXT_LEN - REFERENCE_FRAME):\n            frame_cache.popleft()\n            action_cache.popleft() \n        model.transformer.refresh_kvcache()\n        _frame_iter = itertools.islice(frame_cache, 0, len(frame_cache)-1)\n        _act_iter = itertools.islice(action_cache, 0, len(action_cache))\n        _vis_act = [\n                torch.cat([img, act], dim=1)\n                for img, act in zip(_frame_iter, _act_iter)\n            ]\n        _vis_act.append(frame_cache[-1])\n        _vis_act = torch.cat(_vis_act, dim=-1)\n        _, last_pos = 
model.transformer.prefill_for_gradio(_vis_act)\n\n    act_dict = make_action_dict(btns_choices)\n    act_dict['camera'] = np.array((int(cam_y_input), int(cam_x_input)))\n    ongoing_act = MC_ACTION_MAP.get_action_index_from_actiondict(act_dict, action_vocab_offset=8192)\n    ongoing_act = torch.tensor(ongoing_act).unsqueeze(0).to(\"cuda\")\n    action_cache.append(ongoing_act)\n    with torch.no_grad(), torch.autocast(\"cuda\", dtype=torch.float16):\n        if DIAGD:\n            next_frame, last_pos = model.transformer.diagd_img_token_for_gradio(input_action=ongoing_act, position_id=last_pos, max_new_tokens=TOKEN_PER_IMAGE, windowsize=WINDOWSIZE)  # honor --window_size instead of a hardcoded value\n        else:\n            next_frame, last_pos = model.transformer.decode_img_token_for_gradio(input_action=ongoing_act, position_id=last_pos, max_new_tokens=TOKEN_PER_IMAGE + 1) # +1 to fill kvcache\n\n    last_pos = last_pos[0]\n    next_frame = torch.cat(next_frame, dim=-1).to(\"cuda\")\n    frame_cache.append(next_frame)\n    next_frame = tokenizer.token2image(next_frame)\n    next_frame = Image.fromarray(next_frame)\n    if len(images_show) >= SHOW_FRAMES:\n        images_show.popleft()\n        actions_show.popleft()\n    btns_choices = btns_choices + [np.array((int(cam_y_input), int(cam_x_input)))]\n    actions_show.append(','.join(str(x) for item in btns_choices for x in (item if isinstance(item, np.ndarray) else [item])))\n    images_show.append(next_frame)\n    VIDEO_FRAMES.append(next_frame)\n\n    return next_frame, stack_images(images_show), \"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;\".join([str(x) for x in actions_show])\n\ndef run_prediction_n_times(n, btns_1, btns_2, btns_3, btns_4, btns_5, cam_x_input, cam_y_input):\n    btns_choices = btns_1 + btns_2 + btns_3 + btns_4 + btns_5\n    if cam_x_input is None:\n        cam_x_input = 0\n    if cam_y_input is None:\n        cam_y_input = 0\n    if n is None:\n        n = 1\n    for _ in range(n):\n        yield 
run_prediction(btns_choices, cam_x_input, cam_y_input)\n\ndef step_pred_source_video_right(video_path, start):\n    global VIDEO_FRAMES, frame_cache, action_cache, REFERENCE_FRAME, CONTEXT_LEN, REFERENCE_FRAME\n    VIDEO_FRAMES.clear(); frame_cache.clear(); action_cache.clear()\n    if start is None or start < 0 or start > MAX_FRAME:\n        start = 0\n    return step_video(video_path, start, REFERENCE_FRAME)\n\ndef on_download_button_click(fps=6):\n    if not VIDEO_FRAMES:\n        print(\"The frames list is empty.\")\n        return\n\n    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=\".mp4\", dir=\"/tmp\")\n    video_path = temp_file.name\n    temp_file.close()\n\n    video_writer = cv2.VideoWriter(\n        video_path,\n        cv2.VideoWriter_fourcc(*\"mp4v\"),\n        fps,\n        AGENT_RESOLUTION\n    )\n\n    for frame in VIDEO_FRAMES:\n        frame_np = np.array(frame)\n        frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)\n        video_writer.write(frame_bgr)\n\n    video_writer.release()\n    GENERATED_FILES.append(video_path)\n    os.chmod(video_path, 0o644)\n    return video_path\n\ndef cleanup_files():\n    for video_path in GENERATED_FILES:\n        try:\n            os.remove(video_path)\n            print(f\"Deleted file: {video_path}\")\n        except OSError as e:\n            print(f\"Error deleting file {video_path}: {e}\")\n\natexit.register(cleanup_files)\n\ndef step_video(video_path, start_frame, frame_count):\n    global images_show, actions_show, frame_cache, action_cache, VIDEO_FRAMES, last_pos, CONTEXT_LEN, REFERENCE_FRAME\n    VIDEO_FRAMES = []\n    images_show = []\n    actions_show = []\n    video = cv2.VideoCapture(video_path)\n    json_data = MC_ACTION_MAP.read_jsonl(video_path[:-4]+\".jsonl\")\n\n    frames_tensor = []\n    action_cache = []\n    for i in range(start_frame, start_frame + frame_count):\n        step_action = json_data[i]\n        step_action, _ = 
MC_ACTION_MAP.json_action_to_env_action(step_action)\n        actions_show.append(get_action_line(step_action))\n        \n        act_index = MC_ACTION_MAP.get_action_index_from_actiondict(step_action, action_vocab_offset=8192)\n        act_index = torch.tensor(act_index).unsqueeze(0)\n        action_cache.append(act_index.to(\"cuda\"))\n\n        video.set(cv2.CAP_PROP_POS_FRAMES, i)\n        ret, frame = video.read()\n        try:\n            if not ret:\n                raise ValueError(f\"frame {i} not ret\")\n            cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)\n\n            frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)\n            frame = cv2.resize(frame, AGENT_RESOLUTION, interpolation=cv2.INTER_LINEAR)\n            images_show.append(Image.fromarray(frame))\n            VIDEO_FRAMES.append(Image.fromarray(frame))\n            frames_tensor.append(torch.from_numpy(frame))\n            \n        except Exception as e:\n            print(f\"Could not read frame from video {video_path}: {e}\")\n    video.release()\n    frames_tensor = torch.stack(frames_tensor, dim=0).to(\"cuda\")\n    frames_tensor = frames_tensor.permute(0, 3, 1, 2).float() / 255.0\n    frames_tensor = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(frames_tensor)\n\n    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):\n        images_token = tokenizer.tokenize_images(frames_tensor)\n    images_token = rearrange(images_token, '(b t) h w -> (b t) (h w)', b=1)\n    \n    frame_cache = deque(torch.split(images_token, split_size_or_sections=1, dim=0))\n    action_cache = deque(action_cache)\n    action_cache.pop()\n\n    images_show = deque(images_show)\n    actions_show = deque(actions_show)\n\n    actions_show.pop()\n\n    \n    model.transformer.refresh_kvcache()\n    _frame_iter = itertools.islice(frame_cache, 0, len(frame_cache)-1)\n    _act_iter = itertools.islice(action_cache, 0, len(action_cache))\n\n    \n    
_vis_act = [\n            torch.cat([img, act], dim=1)\n            for img, act in zip(_frame_iter, _act_iter)\n        ]\n    _vis_act.append(frame_cache[-1])\n    _vis_act = torch.cat(_vis_act, dim=-1)\n\n    _, last_pos = model.transformer.prefill_for_gradio(_vis_act)\n\n    while len(images_show) > SHOW_FRAMES:\n        images_show.popleft()\n        actions_show.popleft()\n\n    return stack_images(images_show), \"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;\".join([str(x) for x in actions_show]), None\n\n\ncss = \"\"\"\n.custom-tab h2 {\n    font-size: 34px;  /* font size */\n    font-weight: bold;  /* bold title */\n    color: #ff6600;  /* font color */\n    text-shadow: 1px 1px 2px #000000; /* text shadow */\n}\n\"\"\"\n\nif __name__ == \"__main__\":\n    args = get_args()\n    if args.diagd:\n        DIAGD = True\n        WINDOWSIZE = args.window_size\n    cap = cv2.VideoCapture(args.scene)\n    global MAX_FRAME\n    MAX_FRAME = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 10\n\n    config = OmegaConf.load(args.config)\n    REFERENCE_FRAME = args.reference_frame\n    CONTEXT_LEN = int(config.model.params.transformer_config.params.max_position_embeddings / (TOKEN_PER_ACTION + TOKEN_PER_IMAGE))\n    assert CONTEXT_LEN > REFERENCE_FRAME\n    model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True)\n    tokenizer = model.tokenizer\n    with gr.Blocks(css=css) as demo:\n        with gr.Tab(\"MineWorld\", elem_classes=\"custom-tab\"):\n            source_video_path = gr.Text(value=args.scene, visible=False)\n            with gr.Row():\n                source_video_actions = gr.Markdown(visible=False)\n                instruction = gr.Markdown(\"Press 'Jump to start frame' to initialize or restart the game; you can choose a different scene by setting the start frame anywhere from 0 to 4100\", visible=False)\n            with gr.Row():\n                source_video_images = gr.Image(width=1280, height=360, label=\"last 8 frames\", 
show_fullscreen_button = True, every=1)\n\n            with gr.Row(equal_height=True):                    \n                with gr.Column(min_width=60):\n                    vid_frame_start = gr.Number(step=1, value=0, info=\"start frame\", min_width=20, show_label=False, minimum=0, maximum=MAX_FRAME)\n                        # vid_num_frames = gr.Number(step=1, value=4, label=\"num_frames\", min_width=50)\n                # with gr.Column(min_width=60):\n                    run_steps = gr.Number(step=1, value=1, info=\"Repeat same action n times\", min_width=20, minimum=1, maximum=8, show_label=False)\n                with gr.Column(min_width=60):\n                    btn1 = list(FOR_BACK.keys())\n                    btns_1 = gr.CheckboxGroup(choices=btn1, show_label=False)\n                    vid_right_btn = gr.Button(value=\"Jump to start frame\", size='sm')\n                with gr.Column(min_width=60):\n                    btn2 = list(L_R.keys())\n                    btns_2 = gr.CheckboxGroup(choices=btn2, show_label=False)\n                    predict_run_btn = gr.Button(value=\"Run\", variant=\"primary\", size='sm')\n                with gr.Column(min_width=60):\n                    btn4 = list(JUMP_SPR.keys())\n                    btns_4 = gr.CheckboxGroup(choices=btn4, show_label=False)\n                    download_game_btn = gr.Button(\"Generate Video\", size='sm')\n                with gr.Column(min_width=60):\n                    cam_y_input = gr.Number(step=1, value=0, info=\"camera Y ⬆️(-),0,(+)⬇️\", min_width=20, minimum=-90, maximum=90, show_label=False)\n                    cam_x_input = gr.Number(step=1, value=0, info=\"camera X ⬅️(-),0,(+)➡️\", min_width=20, minimum=-90, maximum=90, show_label=False)\n                \n                    \n            with gr.Row():\n                with gr.Column(min_width=250):\n                    video_display = gr.Video(label=\"video\", width=384, height=224)\n                with 
gr.Column(min_width=200): \n                    predict_result_imgs = gr.Image(label=\"last generated frame\",width=384, height=224)\n                with gr.Column(min_width=200):\n                    with gr.Row():\n                        btn3 = list(ATT_USE_DROP.keys())\n                        btns_3 = gr.CheckboxGroup(choices=btn3, show_label=False)\n                    with gr.Row():\n                        btn5 = list(HORBAR.keys())\n                        btns_5 = gr.CheckboxGroup(choices=btn5, show_label=False)\n\n                    \n                \n            vid_right_btn.click(fn=step_pred_source_video_right, inputs=[source_video_path, vid_frame_start],\n                                outputs=[source_video_images, source_video_actions, predict_result_imgs])                \n            predict_run_btn.click(fn=run_prediction_n_times, inputs=[run_steps, btns_1, btns_2, btns_3, btns_4, btns_5, cam_x_input, cam_y_input],\n                                    outputs=[predict_result_imgs, source_video_images, source_video_actions],)\n            download_game_btn.click(fn=on_download_button_click, inputs=[], outputs=video_display)\n            demo.load(fn=step_pred_source_video_right, inputs=[source_video_path, gr.Number(value=25, visible=False)],\n                                outputs=[source_video_images, source_video_actions, predict_result_imgs])\n    demo.queue()\n    \n    demo.launch(server_name=\"0.0.0.0\", max_threads=256, server_port=7861, share=True)\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==2.6.0\ntorchvision==0.21.0\nomegaconf==2.3.0\ntransformers==4.48.1\nopencv-python==4.11.0.86\nattrs==25.3.0\ndiffusers==0.32.2\ngradio==5.24.0\neinops==0.8.1\ndiffusers==0.32.2\nscipy==1.15.2\ntorch-fidelity==0.3.0\ngym3==0.3.3\ngym==0.26.2\nscikit-learn==1.6.1"
  },
  {
    "path": "scripts/compute_metrics.sh",
    "content": "#!/bin/bash\n\nVIDEO_RESULTS_ROOT_DEFAULT=\"videos\"\nMETRICS_ROOT_DEFAULT=\"metrics_log\"\nJSONL_PATH_DEFAULT=\"validation/validation\"\nIDM_CKPT_DIR=\"checkpoints/IDM\"\n\nVIDEO_RESULTS_ROOT=${1:-$VIDEO_RESULTS_ROOT_DEFAULT}\nMETRICS_ROOT=${2:-$METRICS_ROOT_DEFAULT}\nJSONL_PATH=${3:-$JSONL_PATH_DEFAULT}\n\necho \"VIDEO_RESULTS_ROOT = $VIDEO_RESULTS_ROOT\"\necho \"METRICS_ROOT       = $METRICS_ROOT\"\necho \"JSONL_PATH         = $JSONL_PATH\"\n\n# Loop through each subdirectory in VIDEO_RESULTS_ROOT\nfor video_dir1 in \"$VIDEO_RESULTS_ROOT\"/*/; do\n    # Skip the 'metrics' directory\n    if [ -d \"$video_dir1\" ] && [ \"$(basename \"$video_dir1\")\" != \"metrics\" ]; then\n        # Construct the output file name based on the video directory name\n        fvd_output_file=\"$METRICS_ROOT/fvd_$(basename \"$video_dir1\").json\"\n        echo $fvd_output_file\n        # Run the python command for each video directory\n        python metrics/common_metrics.py --video_dir2 $JSONL_PATH --video_length 15 --channel 3 --size \"(224,384)\" \\\n                --video_dir1 \"$video_dir1\" --output-file \"$fvd_output_file\"\n\n        idm_output_file=\"$METRICS_ROOT/idm_$(basename \"$video_dir1\").json\"\n        python metrics/IDM/inverse_dynamics_model.py --weights $IDM_CKPT_DIR/\"4x_idm.weights\" \\\n            --infer-demo-num 1 --n-frames 15 \\\n            --model $IDM_CKPT_DIR/\"4x_idm.model\" --video-path $video_dir1 \\\n            --output-file \"$idm_output_file\" \\\n            --jsonl-path $JSONL_PATH\n    fi\ndone\n\npython  metrics/tabulate_all_results.py --input_dir $METRICS_ROOT --output_path $METRICS_ROOT/latest_metrics.csv"
  },
  {
    "path": "scripts/inference_16f_models.sh",
    "content": "DATA_ROOT=\"validation/validation\"\n###########################\n#### Inference 1200M models \n###########################\nCONFIG=\"configs/1200M_16f.yaml\"\nCKPT_PATH=\"checkpoints/1200M_16f.ckpt\"\nOUTPUT_PATH=\"./videos/1200M_16f200_demo1gen15_naive\"\npython inference.py \\\n        --data_root $DATA_ROOT  \\\n        --config $CONFIG \\\n        --model_ckpt $CKPT_PATH \\\n        --demo_num 1 --frames 15  \\\n        --accelerate-algo 'naive' \\\n        --top_p 0.8 \\\n        --output_dir  $OUTPUT_PATH"
  },
  {
    "path": "scripts/setup_metrics.sh",
    "content": "# clone common_metrics_on_video_quality repository\ngit clone git@github.com:CIntellifusion/common_metrics_on_video_quality.git \n\n# get IDM weights \nmkdir -p checkpoints/IDM\nwget https://openaipublic.blob.core.windows.net/minecraft-rl/idm/4x_idm.model -O checkpoints/IDM/4x_idm.model\nwget https://openaipublic.blob.core.windows.net/minecraft-rl/idm/4x_idm.weights  -O checkpoints/IDM/4x_idm.weights\n"
  },
  {
    "path": "utils.py",
    "content": "import torch\nimport importlib\nfrom rich import print\nfrom typing import Union\nimport numpy as np\nimport os\n\ndef print0(*args, **kwargs):\n    print(*args, **kwargs)  # python -m rich.color\n\ndef tensor_to_uint8(tensor):\n    tensor = torch.clamp(tensor, -1.0, 1.0)\n    tensor = (tensor + 1.0) / 2.0\n    tensor = (tensor.cpu().numpy() * 255).astype(np.uint8)\n    return tensor\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\ndef load_model_from_config(config, sd, gpu=True, eval_mode=True):\n    model = instantiate_from_config(config)\n    if sd is not None:\n        missing, unexpected = model.load_state_dict(sd, strict=False)\n        if len(missing) != 0:\n            raise ValueError(f\"Missing keys: {missing}\")\n    if gpu:\n        model.cuda()\n    if eval_mode:\n        model.eval()\n    return {\"model\": model}\n\ndef get_valid_dirs(dir1: str, dir2: Union[None, str] = None, dir3: Union[None, str] = None) -> Union[None, str]:\n    if (dir1 is not None) and os.path.isdir(dir1): return dir1\n    elif (dir2 is not None) and os.path.isdir(dir2): return dir2\n    elif (dir3 is not None) and os.path.isdir(dir3): return dir3\n    else: return None\n\ndef get_valid_paths(path1: str, path2: Union[None, str] = None, path3: Union[None, str] = None) -> Union[None, str]:\n    if (path1 is not None) and os.path.isfile(path1): return path1\n    elif (path2 is not None) and os.path.isfile(path2): return path2\n    elif (path3 is not None) and os.path.isfile(path3): return path3\n    else: return None\n\ndef 
load_model(config, ckpt, gpu, eval_mode):\n    if str(ckpt).endswith(\".bin\"):\n        weight = torch.load(ckpt, map_location=\"cpu\")\n    elif ckpt:\n        weight = torch.load(ckpt, map_location=\"cpu\")[\"state_dict\"]\n    else:\n        raise ValueError(f\"Invalid checkpoint path: {ckpt}\")\n    model = load_model_from_config(config.model, weight, gpu=gpu, eval_mode=eval_mode)[\"model\"]\n    model.load_state_dict(weight, strict=False)\n    model.to(torch.float16)\n    return model\n"
  },
  {
    "path": "vae.py",
    "content": "import torch\nimport torch.nn as nn\nimport diffusers\nfrom safetensors.torch import load_file as load_safetensors\nfrom utils import print0, get_valid_paths, tensor_to_uint8\n\n\nclass VAE(nn.Module):\n    def __init__(self,\n                 config_path: str,\n                 ckpt_path: str,\n    ):\n        super().__init__()\n        config_path = get_valid_paths(config_path)\n        print0(f\"[bold magenta]\\[VAE][/bold magenta] Loading VQGAN from {config_path}\")\n        self.model = diffusers.VQModel.from_config(config_path)\n\n        ckpt_path = get_valid_paths(ckpt_path)\n        print0(f\"[bold magenta]\\[VAE][/bold magenta] Use ckpt_path: {ckpt_path}\")\n        self.init_from_ckpt(ckpt_path)\n\n    def init_from_ckpt(\n        self, path: str\n    ) -> None:\n        if path.endswith(\"ckpt\"):\n            ckpt = torch.load(path, map_location=\"cpu\", weights_only=False)\n            if \"state_dict\" in ckpt:\n                weights = ckpt[\"state_dict\"]\n            else:\n                weights = ckpt\n        elif path.endswith(\"safetensors\"):\n            weights = load_safetensors(path)\n        else:\n            raise NotImplementedError\n\n        missing, unexpected = self.load_state_dict(weights, strict=False)\n        print0(\n            f\"[bold magenta]\\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print0(f\"[bold magenta]\\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Missing Keys: {missing}\")\n        # if len(unexpected) > 0:\n        #     print0(f\"[bold magenta]\\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Unexpected Keys: {unexpected}\")\n\n    @torch.no_grad()\n    def tokenize_images(self, x: torch.Tensor, sane_index_shape: bool = True):\n        h = self.model.encoder(x)\n        h = self.model.quant_conv(h)\n   
     if sane_index_shape:\n            orig_sane_index_shape = self.model.quantize.sane_index_shape\n            self.model.quantize.sane_index_shape = True\n        z_q, loss, (perplexity, min_encodings, min_encoding_indices) = self.model.quantize(h)\n        if sane_index_shape:\n            self.model.quantize.sane_index_shape = orig_sane_index_shape\n        return min_encoding_indices\n    \n    # yang ye\n    @torch.no_grad()\n    def token2image(self, tokens):\n        assert tokens.max() < 8192, f\"code max value is {tokens.max()}\"\n        shape = (1, 14, 24, 64) \n        with torch.autocast(device_type='cuda', dtype=torch.float32):\n            quant = self.model.quantize.get_codebook_entry(tokens, shape)\n            quant2 = self.model.post_quant_conv(quant)\n            dec = self.model.decoder(quant2)\n        img = tensor_to_uint8(dec[0]).transpose(1, 2, 0)\n        return img"
  }
]