Repository: microsoft/mineworld Branch: main Commit: 9f49efbcc68a Files: 40 Total size: 243.3 KB Directory structure: gitextract_mrwf1vu0/ ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── configs/ │ ├── 1200M_16f.yaml │ ├── 1200M_32f.yaml │ ├── 300M_16f.yaml │ ├── 700M_16f.yaml │ └── 700M_32f.yaml ├── diagonal_decoding.py ├── inference.py ├── lvm.py ├── mcdataset.py ├── metrics/ │ ├── IDM/ │ │ ├── inverse_dynamics_model.py │ │ └── lib/ │ │ ├── __init__.py │ │ ├── action_head.py │ │ ├── action_mapping.py │ │ ├── actions.py │ │ ├── impala_cnn.py │ │ ├── masked_attention.py │ │ ├── minecraft_util.py │ │ ├── misc.py │ │ ├── mlp.py │ │ ├── normalize_ewma.py │ │ ├── policy.py │ │ ├── scaled_mse_head.py │ │ ├── torch_util.py │ │ ├── tree_util.py │ │ ├── util.py │ │ └── xf.py │ ├── common_metrics.py │ └── tabulate_all_results.py ├── mineworld.py ├── requirements.txt ├── scripts/ │ ├── compute_metrics.sh │ ├── inference_16f_models.sh │ └── setup_metrics.sh ├── utils.py └── vae.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Microsoft Open Source Code of Conduct This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). Resources: - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) Microsoft Corporation. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE ================================================ FILE: README.md ================================================
# MineWorld
A Real-time Interactive World Model on Minecraft [![arXiv](https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2504.08388)   [![Project](https://img.shields.io/badge/Project-Page-blue?logo=homepage&logoColor=white)](https://aka.ms/mineworld)   [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/microsoft/mineworld)
We introduce MineWorld, an interactive world model on Minecraft that brings several key advancements over existing approaches:
* 🕹️ **High generation quality**. Built on a visual-action autoregressive Transformer, MineWorld generates coherent, high-fidelity frames conditioned on both visuals and actions.
* 🕹️ **Strong controllability**. We propose benchmarks for action-following capacity, on which MineWorld shows precise and consistent behavior.
* 🕹️ **Fast inference speed**. With Diagonal Decoding, MineWorld achieves a generation rate of 4 to 7 frames per second, enabling real-time interaction in open-ended game environments.

https://github.com/user-attachments/assets/2f5b4740-badd-453c-970d-061abd367f82

## 🔥 News
* May, 2025: The model checkpoints in the [Huggingface repo](https://huggingface.co/microsoft/mineworld) have been temporarily taken down.
* April, 2025: 🚀 [MineWorld](https://github.com/microsoft/mineworld) was released!
* March, 2025: 🚀 The paper of [Diagonal Decoding](https://arxiv.org/pdf/2503.14070) was released!

## 🔧 Setup
1. Clone this repository and navigate to the MineWorld folder:
```bash
git clone https://github.com/microsoft/mineworld.git
cd mineworld
```
2. We provide a `requirements.txt` file for setting up a pip environment.
```bash
# 1. Prepare the conda environment
conda create -n mineworld python=3.10

# 2. Activate the environment
conda activate mineworld

# 3. Install the dependencies
pip3 install -r requirements.txt
```
We recommend a high-end GPU for inference. All testing and development was done on A100 and H100 GPUs.

## 🎈 Checkpoints
Download the pre-trained models [here](https://huggingface.co/microsoft/mineworld). Each checkpoint has a corresponding config file with the same name in the `configs` folder of this repository. All models share the same VAE checkpoint and config. The checkpoint directory is structured as follows:
```
└── checkpoints
    ├── 300M_16f.ckpt
    ├── 700M_16f.ckpt
    ├── 700M_32f.ckpt
    ├── 1200M_16f.ckpt
    ├── 1200M_32f.ckpt
    ├── vae
    │   ├── config.json
    │   └── vae.ckpt
    ├── validation
    │   └── validation.zip
    └── gradio_scene
        ├── scene.mp4
        └── scene.jsonl
```

## 🚀 Inference
We provide two ways to use our model: interacting with it in a web demo, and running it locally to reproduce the evaluation results in our paper. In addition to downloading the checkpoints and placing them in the `checkpoints` folder, running the web demo also requires `scene.mp4` and `scene.jsonl`. Make sure they are placed in the same directory.

### Run Web Demo
To launch the web demo, run the following command:
```bash
python mineworld.py --scene "path/to/scene.mp4" --model_ckpt "path/to/ckpt" --config "path/to/config"
```
![image](assets/demo.png)

Once the demo is running, you can access the website through the local URL or the public URL displayed in the command line. Initialization and the first action may take some time due to compilation. You can specify the number of reference frames using the `--reference_frame` option, which should be larger than `4` and smaller than the context length of the model (i.e., `16` or `32` depending on the model used). A higher reference frame number generally corresponds to better visual quality. Once the initial state has been set, perform game actions by selecting options in each chatbox. The game progresses when you press the "Run" button, which displays the last `8` frames and the most recent frame separately. Players can also set an action count to repeat an action multiple times.
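If you want to construct actions programmatically, for example to prepare action sequences in the same format used by the validation `.jsonl` files (and, by the same convention, `scene.jsonl`), the following is a minimal sketch based on how `inference.py` and `mcdataset.py` handle actions. The specific action values (`forward`, the camera offsets) are illustrative placeholders:

```python
import numpy as np
from mcdataset import MCDataset, NOOP_ACTION

# Start from the no-op template and set the fields for a
# "walk forward while turning the camera" step (values are illustrative).
action = dict(NOOP_ACTION)
action["forward"] = 1
action["camera"] = np.array([0, 5])  # [y, x] camera offsets

# Map the action dictionary to the discrete action-token ids that are
# appended after each frame's 336 image tokens (image vocabulary size 8192).
dataset = MCDataset()
action_tokens = dataset.get_action_index_from_actiondict(action, action_vocab_offset=8192)
```

This is the same conversion `inference.py` applies to each line of the validation `.jsonl` files before generation.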
Explanations of the buttons in the web demo are as follows:
```
Start frame: select a frame in scene.mp4 by its frame index
Jump to start frame: use the selected frame as the initial state
Camera `X` and `Y`: control the camera movements between `-90` and `90` degrees
Other action buttons: same as the actions in Minecraft
Generate video: save the previous game progress
```

### Run Local Inference
To run inference locally, use the following command:
```bash
python inference.py \
    --data_root "/path/to/validation/dataset" \
    --model_ckpt "path/to/ckpt" \
    --config "path/to/config" \
    --demo_num 1 \
    --frames 15 \
    --accelerate-algo 'naive' \
    --top_p 0.8 \
    --output_dir "path/to/output"
```
Check `scripts/inference_16f_models.sh` for examples. To switch between naive autoregressive decoding and diagonal decoding, set `--accelerate-algo` to `naive` or `image_diagd`, respectively.

After running inference on a set of videos, you can compute the metrics and reproduce the numerical results in our paper with the following scripts:
```bash
bash scripts/setup_metrics.sh # only required the first time
bash scripts/compute_metrics.sh
```
The evaluation outputs will have the following structure:
```
└── videos
    ├── inference_setting1
    │   ├── clip_1.mp4
    │   └── clip_1.json
    ├── inference_setting2
    │   ├── clip_1.mp4
    │   └── clip_1.json
    └── metrics_log
        ├── fvd_inference_setting1.json
        ├── fvd_inference_setting2.json
        ├── idm_inference_setting1.json
        ├── idm_inference_setting2.json
        └── latest_metrics.csv
```
All results are aggregated into `metrics_log/latest_metrics.csv`.

## 💡 Intended Uses
Our model is trained solely on the Minecraft game domain. As a world model, it is given an initial image of the game scene, and the user selects an action from the action list. The model then generates the next scene in which the selected action takes place.

## 🪧 Out-of-scope Uses
Our models are not specifically designed for any tasks or scenarios other than the Minecraft domain. Developers should expect generation failures in such out-of-scope scenarios. Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case, and evaluate and mitigate for privacy, safety, and fairness before using the model within a specific downstream use case, particularly for high-risk scenarios.

## 🤖️ Risks and Limitations
Some of the limitations of this model to be aware of include:
* Quality of Service: MineWorld is trained solely on Minecraft, so it cannot generate results for other video domains (such as internet video), nor can it generate videos at higher resolutions.
* Information Reliability: MineWorld is trained on videos with a fixed resolution, so the results may lose detailed information due to the low resolution.
* MineWorld inherits any biases, errors, or omissions characteristic of its training data, which may be amplified by any AI-generated interpretations.
* MineWorld was developed for research and experimental purposes. Further testing and validation are needed before considering its application in commercial or real-world scenarios.
* Inputting images other than Minecraft scenes will result in incoherent imagery and should not be attempted.
* Users are responsible for sourcing their datasets legally and ethically. This could include securing appropriate copyrights, ensuring consent for use of audio/images, and/or anonymizing data prior to use in research.
## ✏️ BibTeX
```bibtex
@article{guo2025mineworld,
  title={MineWorld: a Real-Time and Open-Source Interactive World Model on Minecraft},
  author={Guo, Junliang and Ye, Yang and He, Tianyu and Wu, Haoyu and Jiang, Yushu and Pearce, Tim and Bian, Jiang},
  year={2025},
  journal={arXiv preprint arXiv:2504.08388},
}
```

## 🤗 Acknowledgments
This codebase borrows code from [VPT](https://github.com/openai/Video-Pre-Training) and [generative-models](https://github.com/Stability-AI/generative-models). We thank them for their efforts and innovations, which have made the development process more efficient and convenient. Thank you to everyone who contributed their wisdom and efforts to this project.

## ☎️ Contact
We welcome feedback and collaboration from our audience. If you have suggestions, questions, or observe unexpected/offensive behavior in our technology, please contact us at `tianyuhe AT microsoft.com`.

## 📄 Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

## 📍 Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.

================================================
FILE: SECURITY.md
================================================

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) * Full paths of source file(s) related to the manifestation of the issue * The location of the affected source code (tag/branch/commit or direct URL) * Any special configuration required to reproduce the issue * Step-by-step instructions to reproduce the issue * Proof-of-concept or exploit code (if possible) * Impact of the issue, including how an attacker might exploit the issue This information will help us triage your report more quickly. If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. ## Preferred Languages We prefer all communications to be in English. ## Policy Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). ================================================ FILE: SUPPORT.md ================================================ # TODO: The maintainer of this repo has not yet edited this file **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? - **No CSS support:** Fill out this template with information about how to file issues and get help. - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* # Support ## How to file issues and get help This project uses GitHub Issues to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new Issue. For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER CHANNEL. WHERE WILL YOU HELP PEOPLE?**. ## Microsoft Support Policy Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
================================================ FILE: configs/1200M_16f.yaml ================================================ model: target: lvm.LlamaLVM params: model_class: lvm.LlamaForCausalLM tokenizer_config: target: vae.VAE params: config_path: checkpoints/vae/config.json ckpt_path: checkpoints/vae/vae.ckpt transformer_config: target: transformers.LlamaConfig params: max_position_embeddings: 5552 hidden_size: 2048 intermediate_size: 8192 num_attention_heads: 32 num_key_value_heads: 4 num_hidden_layers: 20 rope_theta: 10000.0 torch_dtype: bfloat16 rms_norm_eps: 1.0e-05 vocab_size: 8262 attention_bias: false mlp_bias: false ================================================ FILE: configs/1200M_32f.yaml ================================================ model: target: lvm.LlamaLVM params: model_class: lvm.LlamaForCausalLM tokenizer_config: target: vae.VAE params: config_path: checkpoints/vae/config.json ckpt_path: checkpoints/vae/vae.ckpt transformer_config: target: transformers.LlamaConfig params: max_position_embeddings: 11104 hidden_size: 2048 intermediate_size: 8192 num_attention_heads: 32 num_key_value_heads: 4 num_hidden_layers: 20 rope_theta: 10000.0 torch_dtype: bfloat16 rms_norm_eps: 1.0e-05 vocab_size: 8262 attention_bias: false mlp_bias: false ================================================ FILE: configs/300M_16f.yaml ================================================ model: target: lvm.LlamaLVM params: model_class: lvm.LlamaForCausalLM tokenizer_config: target: vae.VAE params: config_path: checkpoints/vae/config.json ckpt_path: checkpoints/vae/vae.ckpt transformer_config: target: transformers.LlamaConfig params: max_position_embeddings: 5552 hidden_size: 1024 intermediate_size: 4096 num_attention_heads: 16 num_key_value_heads: 4 num_hidden_layers: 20 initializer_range: 0.02 rope_theta: 10000.0 torch_dtype: bfloat16 rms_norm_eps: 1.0e-05 vocab_size: 8262 attention_bias: false mlp_bias: false token_num: 347 image_num: 336 frame: 16 ================================================ FILE: configs/700M_16f.yaml ================================================ model: target: lvm.LlamaLVM params: model_class: lvm.LlamaForCausalLM tokenizer_config: target: vae.VAE params: config_path: checkpoints/vae/config.json ckpt_path: checkpoints/vae/vae.ckpt transformer_config: target: transformers.LlamaConfig params: max_position_embeddings: 5552 hidden_size: 2048 intermediate_size: 4096 num_attention_heads: 32 num_key_value_heads: 4 num_hidden_layers: 20 rope_theta: 10000.0 torch_dtype: bfloat16 rms_norm_eps: 1.0e-05 vocab_size: 8262 attention_bias: false mlp_bias: false ================================================ FILE: configs/700M_32f.yaml ================================================ model: target: lvm.LlamaLVM params: model_class: lvm.LlamaForCausalLM tokenizer_config: target: vae.VAE params: config_path: checkpoints/vae/config.json ckpt_path: checkpoints/vae/vae.ckpt transformer_config: target: transformers.LlamaConfig params: max_position_embeddings: 11104 hidden_size: 2048 intermediate_size: 4096 num_attention_heads: 32 num_key_value_heads: 4 num_hidden_layers: 20 rope_theta: 10000.0 torch_dtype: bfloat16 rms_norm_eps: 1.0e-05 vocab_size: 8262 attention_bias: false mlp_bias: false ================================================ FILE: diagonal_decoding.py ================================================ import torch from typing import Optional from torch.nn.attention import SDPBackend def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192): 
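    # Samples only from the first `vocab_size` logits (the image-token vocabulary); greedy argmax when temperature == 0.0.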
""" Sample from the logits using top-k sampling. Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py """ # logits: [batch_size, seq_len, vocab_size] if temperature == 0.0: idx_next = torch.argmax(logits[:, -1, :vocab_size], dim=-1, keepdim=True) else: probs = logits_to_probs(logits[:, -1, :vocab_size], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int): """ Multinomial sampling without a cuda synchronization. Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py """ q = torch.empty_like(probs_sort).exponential_(1) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype) def logits_to_probs( logits, temperature: float = 1.0, top_k: Optional[int] = None, ): logits = logits / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) pivot = v.select(-1, -1).unsqueeze(-1) logits = torch.where(logits < pivot, -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) return probs def sample_top_p(logits, temperature, top_p, vocab_size=8192): probs = torch.softmax(logits[:, -1, :vocab_size] / temperature, dim=-1) probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) mask = probs_sum - probs_sort > top_p probs_sort[mask] = 0.0 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64) next_token = torch.gather(probs_idx, -1, next_token) return next_token def sample_n_top_p(logits, temperature, top_p, vocab_size=8192): probs = torch.softmax(logits[:, :, :vocab_size] / temperature, dim=-1) probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) mask = probs_sum - probs_sort > top_p probs_sort[mask] = 0.0 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64) next_token = torch.gather(probs_idx, -1, next_token) return next_token def sample_n_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192): if temperature == 0.0: # Modify for multiple logits (n items) idx_next = torch.argmax(logits[:, :, :vocab_size], dim=-1, keepdim=True) # Use all n logits for top-k probs = None else: probs = logits_to_n_probs(logits[:, :, :vocab_size], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next def logits_to_n_probs( logits, temperature: float = 1.0, top_k: Optional[int] = None, ): logits = logits / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1) pivot = v.select(-1, -1).unsqueeze(-1) logits = torch.where(logits < pivot, -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) return probs def decode_one_token( model, input_ids: torch.Tensor, position_ids: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, ): """ Decode a single token from the autoregressive model. 
""" logits = model(input_ids=input_ids, position_ids=position_ids) if top_p is not None: return sample_top_p(logits, temperature=temperature, top_p=top_p) else: return sample_top_k(logits, temperature=temperature, top_k=top_k) def decode_some_token( model, input_ids: torch.Tensor, position_ids: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, ): """ Decode multi token from the autoregressive model. """ logits = model(input_ids=input_ids, position_ids=position_ids) if top_p is not None: return sample_n_top_p(logits, temperature=temperature, top_p=top_p) else: return sample_n_top_k(logits, temperature=temperature, top_k=top_k) def decode_n_tokens( model, input_ids: torch.Tensor, position_ids: torch.Tensor, num_generate_tokens: int, temperature: float = 1.0, top_p: Optional[float] = 0.8, top_k: Optional[int] = None, decode_one_token_function=decode_one_token, pixnum: int = 336, actnum: int = 11, **kwargs, ): """ Decode n tokens from the autoregressive model. Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py """ new_tokens = [input_ids] pos_ = position_ids assert ( top_p is None or top_k is None ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" for t in range(num_generate_tokens): with torch.nn.attention.sdpa_kernel( SDPBackend.MATH ): # Actually better for Inductor to codegen attention here next_token = decode_one_token_function( model, input_ids=input_ids, position_ids=position_ids, temperature=temperature, top_k=top_k, top_p=top_p, ) pos_ += 1 position_ids = pos_ new_tokens.append(next_token.clone()) input_ids = next_token.clone() if (pos_ - pixnum + 1) % (actnum + pixnum) == 0 and t+2 < num_generate_tokens: action = kwargs["action"][ (t+2) // pixnum ] input_ids = torch.cat((input_ids, action), dim=-1) position_ids = torch.tensor([pos_ + _ for _ in range(actnum+1)], dtype=torch.long, device="cuda") pos_ += actnum return new_tokens def decode_n_tokens_for_gradio( model, input_ids: torch.Tensor, position_ids: torch.Tensor, num_generate_tokens: int, temperature: float = 1.0, top_p: Optional[float] = 0.8, top_k: Optional[int] = None, decode_one_token_function=decode_one_token, ): """ Decode n tokens from the autoregressive model. 
Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py """ new_tokens = [] assert ( top_p is None or top_k is None ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" position_id = position_ids[-1].unsqueeze(0) assert num_generate_tokens % 336 == 1, "should be pixnum x n + 1 to fill kvcache" for t in range(num_generate_tokens): with torch.nn.attention.sdpa_kernel( SDPBackend.MATH ): # Actually better for Inductor to codegen attention here next_token = decode_one_token_function( model, input_ids=input_ids, position_ids=position_ids, temperature=temperature, top_k=top_k, top_p=top_p, ) position_id += 1 position_ids = position_id new_tokens.append(next_token.clone()) input_ids = next_token.clone() return new_tokens[:-1], position_id def prefill( model, input_ids: torch.Tensor = None, position_ids: torch.Tensor = None, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = 0.8, **kwargs, ): logits = model(input_ids=input_ids, position_ids=position_ids) # Only top-p or top-k can be provided assert ( top_p is None or top_k is None ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" if top_p is not None: return sample_top_p(logits, temperature=temperature, top_p=top_p) else: return sample_top_k(logits, temperature=temperature, top_k=top_k) def img_diagd_prepare_inputs( ongoing_row_list, row_token_num, ongoing_input, prompt, imagenum, pixnum: int = 336, actnum: int = 11, columnnum: int = 24, promptlen: int = 347, **kwargs ): position_ids = [] for i in ongoing_row_list: global_idx = promptlen + i * columnnum + row_token_num[i] - 1 + (imagenum - 1) * (pixnum + actnum) position_ids.append(global_idx) if row_token_num[ongoing_row_list[-1]] == 0: append_policy = kwargs.get("append_policy", True) if append_policy: idx_in_input_ids = ongoing_row_list[-1] * columnnum - 1 ongoing_input.append(prompt[:, idx_in_input_ids].unsqueeze(-1)) else: ongoing_input.append(ongoing_input[-1]) input_ids = torch.cat(ongoing_input, dim=1) position_ids = torch.tensor(position_ids, device="cuda") return input_ids, position_ids def img_diagd_decode_n_tokens( model, input_ids: torch.Tensor, position_ids: torch.Tensor, num_generate_tokens: int, temperature: float = 1.0, top_p: Optional[float] = 0.8, top_k: Optional[int] = None, decode_some_token_function=decode_some_token, pixnum: int = 336, actnum: int = 11, columnnum: int = 24, rownum: int = 14, windowsize: int = 2, promptlen: int = 347, **kwargs, ): assert ( top_p is None or top_k is None ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" imagenum = 1 cur_len = 1 num_generate_tokens += 1 prompt = kwargs.pop("prompt", None) new_tokens = [input_ids.clone()] row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda") row_token_num[0] += 1 ongoing_row_list = [0] ongoing_input = [input_ids.clone()] while True: if cur_len >= num_generate_tokens: break if cur_len % pixnum == 0 :#and image_start_token_id_index is None: imagenum += 1 action = kwargs["action"][cur_len // pixnum] ongoing_input.append(action) input_id = torch.cat(ongoing_input, dim=-1) position_ids = torch.arange(imagenum * (pixnum+actnum) - actnum - 1, imagenum * (pixnum+actnum), device="cuda") image_token_num = cur_len % pixnum if image_token_num == 1 and row_token_num[0] == windowsize: ongoing_row_list.append(1) if image_token_num >= 1: input_id, position_ids = img_diagd_prepare_inputs(ongoing_row_list=ongoing_row_list, ongoing_input = ongoing_input, 
imagenum=imagenum, row_token_num=row_token_num, promptlen=promptlen, prompt=prompt,**kwargs) num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1 with torch.nn.attention.sdpa_kernel( SDPBackend.MATH ): # Actually better for Inductor to codegen attention here next_token = decode_some_token_function( model, input_ids=input_id, position_ids=position_ids, temperature=temperature, top_k=top_k, top_p=top_p, ) ongoing_input = [] if len(ongoing_row_list) == 0: cur_len += 1 ongoing_input.append(next_token[:,-1].clone()) new_tokens.append(next_token[:,-1].clone()) ongoing_row_list.append(0) row_token_num[0] += 1 else: need_remove_row = None cur_len += num_new_tokens for i in range(num_new_tokens): position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0) + (imagenum - 1) * pixnum new_tokens.insert(position_in_new_tokens, next_token[:,i].clone()) ongoing_input.append(next_token[:,i].clone()) row_token_num[ongoing_row_list[i]] += 1 if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1: ongoing_row_list.append(ongoing_row_list[i]+1) elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum: row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda") ongoing_row_list = [] ongoing_input = [next_token[:,i]] need_remove_row = None break if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done ongoing_input.pop() need_remove_row = ongoing_row_list[i] if need_remove_row is not None: ongoing_row_list.remove(need_remove_row) return new_tokens def img_diagd_prepare_inputs_for_gradio( ongoing_row_list, row_token_num, ongoing_input, pixnum: int = 336, actnum: int = 11, columnnum: int = 24, promptlen: int = 347, ): position_ids = [] for i in ongoing_row_list: global_idx = promptlen + i * columnnum + row_token_num[i] - 1 position_ids.append(global_idx) if row_token_num[ongoing_row_list[-1]] == 0: ongoing_input.append(ongoing_input[-1]) input_ids = torch.cat(ongoing_input, dim=1) position_ids = torch.tensor(position_ids, device="cuda") return input_ids, position_ids def img_diagd_decode_n_token_for_gradio( model, input_ids: torch.Tensor, position_ids: torch.Tensor, num_generate_tokens: int, temperature: float = 1.0, top_p: Optional[float] = 0.8, top_k: Optional[int] = None, decode_some_token_function=decode_some_token, pixnum: int = 336, columnnum: int = 24, rownum: int = 14, windowsize: int = 2, ): assert ( top_p is None or top_k is None ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" cur_len = 0 promptlen = position_ids[-1] + 1 new_tokens = [] row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda") ongoing_row_list = [] ongoing_input = [] while True: if cur_len == num_generate_tokens: break image_token_num = cur_len if image_token_num == 1 and row_token_num[0] == windowsize: ongoing_row_list.append(1) if image_token_num == 0: input_id = input_ids if image_token_num >=1: input_id, position_ids = img_diagd_prepare_inputs_for_gradio(ongoing_row_list=ongoing_row_list, ongoing_input = ongoing_input, row_token_num=row_token_num, promptlen=promptlen) num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1 with torch.nn.attention.sdpa_kernel( SDPBackend.MATH ): # Actually better for Inductor to codegen attention here next_token = decode_some_token_function( model, input_ids=input_id, position_ids=position_ids, temperature=temperature, top_k=top_k, top_p=top_p, ) ongoing_input = [] if len(ongoing_row_list) == 0: 
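                # No rows are in flight yet: the first sampled token of this frame seeds row 0 of the diagonal schedule.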
cur_len += 1 ongoing_input.append(next_token[:,-1].clone()) new_tokens.append(next_token[:,-1].clone()) ongoing_row_list.append(0) row_token_num[0] += 1 else: need_remove_row = None cur_len += num_new_tokens for i in range(num_new_tokens): position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0) new_tokens.insert(position_in_new_tokens, next_token[:,i].clone()) ongoing_input.append(next_token[:,i].clone()) row_token_num[ongoing_row_list[i]] += 1 if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1: ongoing_row_list.append(ongoing_row_list[i]+1) elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum: row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda") ongoing_row_list = [] ongoing_input = [next_token[:,i]] need_remove_row = None break if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done ongoing_input.pop() need_remove_row = ongoing_row_list[i] if need_remove_row is not None: ongoing_row_list.remove(need_remove_row) with torch.nn.attention.sdpa_kernel( SDPBackend.MATH ): # Actually better for Inductor to codegen attention here _ = decode_some_token_function( model, input_ids=next_token[:,-1], position_ids=position_ids+1, temperature=temperature, top_k=top_k, top_p=top_p, ) return new_tokens, position_ids+2 def vid_diagd_prepare_inputs( ongoing_row_list_v, row_token_num_v, ongoing_input_v, prompt, pixnum: int = 336, actnum: int = 11, rownum: int = 14, columnnum: int = 24, promptlen: int = 347, **kwargs ): new_frame = False position_ids = [] for i in ongoing_row_list_v: global_idx = promptlen + i * columnnum + row_token_num_v[i // rownum][i % rownum] -1 + (i // rownum) * actnum position_ids.append(global_idx) lastrow = ongoing_row_list_v[-1] if lastrow % rownum == 0 and row_token_num_v[lastrow // rownum][lastrow % rownum] == 0: # WARNING action = kwargs["action"][lastrow // rownum] ongoing_input_v.append(action) position_ids.pop() pos_act = torch.arange( promptlen + (lastrow // rownum) * (pixnum+actnum) - actnum, promptlen + (lastrow // rownum) * (pixnum+actnum), device="cuda") position_ids.extend(pos_act.unbind()) new_frame = True elif row_token_num_v[lastrow // rownum][lastrow % rownum] == 0: append_policy = kwargs.get("append_policy", True) if append_policy: idx_in_input_ids = (lastrow % rownum) * columnnum - 1 ongoing_input_v.append(prompt[:, idx_in_input_ids].unsqueeze(-1)) else: ongoing_input_v.append(ongoing_input_v[-1]) input_ids = torch.cat(ongoing_input_v, dim=1) position_ids = torch.tensor(position_ids, device="cuda") return input_ids, position_ids, new_frame def video_diagd_decode_n_tokens( model, input_ids: torch.Tensor, position_ids: torch.Tensor, num_generate_tokens: int, temperature: float = 1.0, top_p: Optional[float] = 0.8, top_k: Optional[int] = None, decode_some_token_function=decode_some_token, pixnum: int = 336, actnum: int = 11, columnnum: int = 24, rownum: int = 14, windowsize: int = 2, promptlen: int = 347, **kwargs, ): assert ( top_p is None or top_k is None ), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}" cur_len = 1 num_generate_tokens += 1 prompt = kwargs.pop("prompt", None) new_tokens = [input_ids.clone()] row_token_num_v = [] ongoing_row_list_v = [0] row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device="cuda")) row_token_num_v[0][0] += 1 if row_token_num_v[0][0] == windowsize: ongoing_row_list_v.append(1) ongoing_input_v = [input_ids.clone()] while True: if cur_len >= 
num_generate_tokens: break input_id, position_ids, new_frame = vid_diagd_prepare_inputs(ongoing_row_list_v=ongoing_row_list_v, ongoing_input_v = ongoing_input_v, row_token_num_v=row_token_num_v, promptlen=promptlen, prompt=prompt, **kwargs) num_new_tokens = input_id.shape[1] with torch.nn.attention.sdpa_kernel( SDPBackend.MATH ): # Actually better for Inductor to codegen attention here next_token = decode_some_token_function( model, input_ids=input_id, position_ids=position_ids, temperature=temperature, top_k=top_k, top_p=top_p, ) ongoing_input_v = [] if new_frame: next_token = torch.cat([next_token[:,:-actnum], next_token[:,-1:]], dim=1) num_new_tokens = num_new_tokens - actnum + 1 need_remove_row = None cur_len += num_new_tokens for i in range(num_new_tokens): last_frame = torch.stack(row_token_num_v[:ongoing_row_list_v[i] // rownum]).sum() if ongoing_row_list_v[i] // rownum > 0 else torch.tensor(0, dtype=torch.long, device="cuda") position_in_new_tokens = last_frame + torch.sum(row_token_num_v[ongoing_row_list_v[i] // rownum][:(ongoing_row_list_v[i] % rownum + 1)], dim=0) new_tokens.insert(position_in_new_tokens, next_token[:,i].clone()) ongoing_input_v.append(next_token[:,i].clone()) row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] += 1 # WARNING if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == windowsize and ongoing_row_list_v[i] < rownum * (num_generate_tokens//pixnum) - 1: ongoing_row_list_v.append(ongoing_row_list_v[i]+1) if ongoing_row_list_v[-1] % rownum == 0: row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device="cuda")) if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == columnnum: ongoing_input_v.pop() need_remove_row = ongoing_row_list_v[i] if need_remove_row is not None: ongoing_row_list_v.remove(need_remove_row) return new_tokens ================================================ FILE: inference.py ================================================ import os import cv2 import torch import time import numpy as np from tqdm import tqdm from rich import print from PIL import Image from pathlib import Path from torch import autocast from einops import rearrange from mcdataset import MCDataset from omegaconf import OmegaConf from torchvision import transforms from argparse import ArgumentParser from utils import load_model, tensor_to_uint8 torch.backends.cuda.matmul.allow_tf32 = False ACCELERATE_ALGO = [ 'naive','image_diagd' ] TARGET_SIZE=(224,384) TOKEN_PER_IMAGE = 347 # IMAGE = PIX+ACTION TOKEN_PER_PIX = 336 safe_globals = {"array": np.array} def token2video(code_list, tokenizer, save_path, fps, device = 'cuda'): """ change log: we don't perform path processing inside functions to enable extensibility save_path: str, path to save the video, expect to endwith .mp4 """ if len(code_list) % TOKEN_PER_PIX != 0: print(f"code_list length {len(code_list)} is not multiple of {TOKEN_PER_PIX}") return num_images = len(code_list) // TOKEN_PER_PIX fourcc = cv2.VideoWriter_fourcc(*'mp4v') video = cv2.VideoWriter(save_path, fourcc, fps, (384, 224)) for i in range(num_images): code = code_list[i*TOKEN_PER_PIX:(i+1)*TOKEN_PER_PIX] code = torch.tensor([int(x) for x in code], dtype=torch.long).to(device) img = tokenizer.token2image(code) # pixel frame = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) video.write(frame) video.release() def get_args(): parser = ArgumentParser() parser.add_argument('--data_root', type=str, required=True) parser.add_argument('--model_ckpt', type=str, 
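                        # path to a MineWorld transformer checkpoint, e.g. checkpoints/700M_16f.ckpt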
required=True) parser.add_argument('--config', type=str, required=True) parser.add_argument('--output_dir', type=str, required=True) parser.add_argument('--demo_num', type=int, default=1) parser.add_argument('--frames', type=int, required=True) parser.add_argument('--window_size', type=int, default=2) parser.add_argument('--accelerate-algo', type=str, default='naive', help=f"Accelerate Algorithm Option: {ACCELERATE_ALGO}") parser.add_argument('--fps', type=int, default=6) group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--top_k', type=int, help='Use top-k sampling') group.add_argument('--top_p', type=float, help='Use top-p (nucleus) sampling') parser.add_argument('--val_data_num', type=int, default=500, help="number of validation data") args = parser.parse_args() return args def lvm_generate(args, model, output_dir, demo_video): """ """ ### 1. set video input/output path input_mp4_path = os.path.join(args.data_root, demo_video) input_action_path = os.path.join(args.data_root, demo_video.replace('mp4','jsonl')) output_mp4_path = str(output_dir / demo_video) output_action_path = output_mp4_path.replace('.mp4', '.jsonl') # backup action os.system(f"cp {input_action_path} {output_action_path}") if os.path.exists(output_mp4_path): print(f"output path {output_mp4_path} exist") return {} device = model.transformer.device ### 2. load action into list action_list = [] mcdataset = MCDataset() with open(input_action_path, 'r') as f: for line in f: line = eval(line.strip(), {"__builtins__": None}, safe_globals) line['camera'] = np.array(line['camera']) act_index = mcdataset.get_action_index_from_actiondict(line, action_vocab_offset=8192) action_list.append(act_index) ### 3. load video frames cap = cv2.VideoCapture(input_mp4_path) start_frame = 0 end_frame = args.demo_num frames = [] for frame_idx in range(start_frame, end_frame): cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, frame = cap.read() if not ret: print(f"Error in reading frame {frame_idx}") continue cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame) frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8) frame = torch.from_numpy(frame) frames.append(frame) frames = torch.stack(frames, dim=0).to(device) frames = frames.permute(0, 3, 1, 2) frames = frames.float() / 255.0 normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) frames = normalize(frames) with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16): img_index = model.tokenizer.tokenize_images(frames) img_index = rearrange(img_index, '(b t) h w -> b t (h w)', b=1) all_generated_tokens = [] action_all = action_list[end_frame: end_frame + args.frames] action_all = torch.tensor(action_all).unsqueeze(1).to(device) image_input = rearrange(img_index, 'b t c -> b (t c)') start_t = time.time() with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16): if args.accelerate_algo == 'naive': outputs = model.transformer.naive_generate( input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all, top_k=args.top_k, top_p=args.top_p) elif args.accelerate_algo == 'image_diagd': outputs = model.transformer.img_diagd_generate(input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all,windowsize = args.window_size, top_k=args.top_k, top_p=args.top_p) else: raise ValueError(f"Unknown accelerate algorithm {args.accelerate_algo}") end_t = time.time() all_generated_tokens.extend(outputs.tolist()[0]) new_length = len(all_generated_tokens) time_costed = end_t 
- start_t token_per_sec = new_length / time_costed frame_per_sec = token_per_sec / TOKEN_PER_PIX print(f"{new_length} token generated; cost {time_costed:.3f} second; {token_per_sec:.3f} token/sec {frame_per_sec:.3f} fps") token2video(all_generated_tokens, model.tokenizer, str(output_path / demo_video), args.fps, device) # return for evaluation return_item = { "time_costed": time_costed, "token_num": new_length, } return return_item if __name__ == '__main__': args = get_args() config = OmegaConf.load(args.config) output_path = Path(args.output_dir) precision_scope = autocast os.makedirs(output_path, exist_ok=True) model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True) print(f"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] Load Model From {args.model_ckpt}") # get accelearte algoritm args.accelerate_algo = args.accelerate_algo.lower() if args.accelerate_algo not in ACCELERATE_ALGO: print(f"[bold red][Warning][/bold red] {args.accelerate_algo} is not in {ACCELERATE_ALGO}, use naive") args.accelerate_algo = 'naive' num_item = 0 for root, _, files in os.walk(args.data_root): files = [f for f in files if f.endswith('.mp4')] # mp4 would not influence progress bar files = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0])) for file in tqdm(files): return_item = lvm_generate(args, model, output_path,file) num_item += 1 if num_item >= args.val_data_num: print(f"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] reach val data num limit {args.val_data_num}") break ================================================ FILE: lvm.py ================================================ """ Wrap the Huggingface Transformers Llama to PyTorch Lightning Module. """ import os import sys import inspect import torch from typing import Optional import torch.utils.checkpoint from torch import nn from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers import LlamaConfig from utils import get_obj_from_str, instantiate_from_config from diagonal_decoding import decode_one_token, decode_some_token, decode_n_tokens, decode_n_tokens_for_gradio, prefill, img_diagd_decode_n_tokens, video_diagd_decode_n_tokens, img_diagd_decode_n_token_for_gradio torch.backends.cuda.matmul.allow_tf32 = False logger = logging.get_logger(__name__) currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) parentdir = os.path.dirname(currentdir) if not (parentdir in sys.path): sys.path.insert(0, parentdir) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): """ Apply rotary position embeddings to query and key tensors. Args: q (torch.Tensor): Query tensor. k (torch.Tensor): Key tensor. cos (torch.Tensor): Cosine values. sin (torch.Tensor): Sine values. position_ids (torch.Tensor): Position IDs. Returns: torch.Tensor: Query and key tensors with rotary position embeddings applied. 
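    Note: `cos` and `sin` are indexed by `position_ids` and broadcast over the batch and head dimensions before being applied.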
""" cos = cos[position_ids].unsqueeze(0).unsqueeze(2) sin = sin[position_ids].unsqueeze(0).unsqueeze(2) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class LlamaLVM(torch.nn.Module): def __init__( self, transformer_config, model_class: str, tokenizer_config = None, ): super().__init__() self.config = instantiate_from_config(transformer_config) self.transformer = get_obj_from_str(model_class)(self.config) if tokenizer_config is not None: self.tokenizer = instantiate_from_config(tokenizer_config) class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class LlamaRotaryEmbedding(nn.Module): def __init__( self, device=None, config: Optional[LlamaConfig] = None, ): super().__init__() self.rope_kwargs = {} self.rope_type = "default" self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] self.max_position_embeddings = config.max_position_embeddings inv_freq, _ = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self._set_cos_sin_cache( device=self.inv_freq.device, dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, device, dtype): """ Set the cosine and sine cache for positional embeddings. Args: seq_len (int): The sequence length. device (str): The device on which the cache tensors will be stored. dtype: The data type of the cache tensors. """ t = torch.arange( self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype ) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer( "cos_cached", emb.cos().to(dtype), persistent=False ) self.register_buffer( "sin_cached", emb.sin().to(dtype), persistent=False ) def forward(self, x, seq_len=None): """ Forward pass of the LlamaRotaryEmbedding module. Args: x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size]. seq_len (int): The sequence length. If greater than the cached length, the cache will be updated. Returns: tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim]. 
""" if seq_len > self.max_position_embeddings: raise ValueError("seq length should less than max embedding") return ( self.cos_cached[:seq_len, :].to(dtype=x.dtype), self.sin_cached[:seq_len, :].to(dtype=x.dtype), ) class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class LlamaAttention(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings assert (self.head_dim * self.num_heads) == self.hidden_size, "hidden_size must be divisible by num_heads" self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias ) self.rotary_emb = LlamaRotaryEmbedding(config=config) self.max_batch_size = getattr(config, "max_batch_size", 1) self.init_kv_cache() def init_kv_cache(self, dtype=torch.float16): cache_shape = (self.max_batch_size, self.max_position_embeddings, self.num_key_value_heads, self.head_dim) self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, positions_embedding = None, ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim ) key_states = key_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ) value_states = value_states.view( bsz, q_len, self.num_key_value_heads, self.head_dim ) cos, sin = positions_embedding query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) self.cache_k[:bsz, position_ids] = key_states self.cache_v[:bsz, position_ids] = value_states key_states, value_states = ( self.cache_k[:bsz, :, :], self.cache_v[:bsz, :, :], ) key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=2) value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=2) query_states, key_states, value_states = map(lambda x: x.transpose(1, 2), (query_states, key_states, value_states)) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, 
is_causal=False, ).transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() self.hidden_size = config.hidden_size self.self_attn = LlamaAttention(config=config) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, positions_embedding = None, ): """ Forward pass for the LlamaDecoderLayer. Args: hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`. attention_mask (torch.FloatTensor, optional): Attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. position_ids (torch.LongTensor, optional): Positional IDs tensor. Returns: Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing: - hidden_states (torch.FloatTensor): Output tensor. """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, positions_embedding=positions_embedding, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class LlamaModel(PreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] Args: config: LlamaConfig """ def __init__(self, config: LlamaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) self.max_position_embedding = config.max_position_embeddings self.causal_mask = torch.tril( torch.ones(self.max_position_embedding, self.max_position_embedding, dtype=torch.bool) ).cuda() self.post_init() def _create_attention_mask(self, input_pos: Optional[torch.Tensor]): """ Creates an attention mask for the transformer layers. Args: input_pos[torch.Tensor]: The position of input sequence (used for inference only). Returns: Optional[torch.Tensor]: The attention mask, or None for causal mask. 
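        The mask is built by indexing rows of the precomputed lower-triangular matrix with the input positions, so each query can attend to every position up to and including its own.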
""" mask = self.causal_mask[input_pos] return mask def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ): if input_ids is None: raise ValueError( "decoder_input_ids is None" ) hidden_states = self.embed_tokens(input_ids) positions_embedding = self.rotary_emb(hidden_states, seq_len=self.max_position_embedding) attention_mask = self._create_attention_mask(input_pos=position_ids) for idx, decoder_layer in enumerate(self.layers): layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, positions_embedding=positions_embedding, ) hidden_states = layer_outputs hidden_states = self.norm(hidden_states) return hidden_states class LlamaForCausalLM(PreTrainedModel): def __init__(self, config): super().__init__(config) self.model = LlamaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def forward( self, input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, ): outputs = self.model( input_ids=input_ids, position_ids=position_ids, ) logits = self.lm_head(outputs[:, :, :]) return logits def refresh_kvcache(self): for i in self.model.layers: i.self_attn.init_kv_cache() def naive_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, top_p=None, top_k=None): self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) if action_all is not None: input_ids = torch.cat([input_ids, action_all[0]], dim=-1) position_ids = torch.arange(0, input_ids.shape[1], device="cuda") next_token = self.prefill( self, input_ids=input_ids, position_ids=position_ids, temperature=temperature, top_k = top_k, top_p = top_p, ) self.decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True) position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda") generated_tokens = decode_n_tokens( self, input_ids = next_token.view(1, -1), position_ids = position_ids, num_generate_tokens = max_new_tokens - 1, temperature = temperature, decode_one_token_function=self.decode_one_token, action=action_all, top_p = top_p, top_k = top_k, ) return torch.cat(generated_tokens, dim=1) def prefill_for_gradio(self, input_ids, temperature=1.0): self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) last_pos = input_ids.shape[1] position_ids = torch.arange(0, last_pos, device="cuda") with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16): next_token = self.prefill( self, input_ids=input_ids, position_ids=position_ids, temperature=temperature, ) return next_token, last_pos def decode_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0): self.decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True) # self.decode_one_token = decode_one_token # WARNING position_ids = torch.arange(position_id, position_id + input_action.shape[1], device="cuda") with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16): generated_tokens, position_id = decode_n_tokens_for_gradio( self, input_ids = input_action, position_ids = position_ids, num_generate_tokens = max_new_tokens, temperature = temperature, decode_one_token_function=self.decode_one_token, ) # WARNING return generated_tokens, position_id def diagd_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0, windowsize=2): self.decode_some_token = 
torch.compile(decode_some_token, mode="max-autotune", fullgraph=True) position_ids = torch.arange(position_id, position_id + input_action.shape[1], device="cuda") with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16): generated_tokens, position_id = img_diagd_decode_n_token_for_gradio( self, input_ids = input_action, position_ids = position_ids, num_generate_tokens = max_new_tokens, temperature = temperature, decode_some_token_function=self.decode_some_token, windowsize = windowsize, ) return generated_tokens, position_id def img_diagd_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, windowsize=2, top_p=None, top_k=None): self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) input_ids = torch.cat([input_ids, action_all[0]], dim=-1) position_ids = torch.arange(0, input_ids.shape[1], device="cuda") next_token = self.prefill( self, input_ids=input_ids, position_ids=position_ids, temperature=temperature, top_k = top_k, top_p = top_p, ) self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True) position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda") generated_tokens = img_diagd_decode_n_tokens( self, input_ids = next_token.view(1, -1), position_ids = position_ids, num_generate_tokens = max_new_tokens - 1, temperature = temperature, decode_some_token_function=self.decode_some_token, windowsize = windowsize, action=action_all, prompt=input_ids, top_k = top_k, top_p = top_p, ) return torch.cat(generated_tokens, dim=1) def vid_diagd_generate(self, input_ids, max_new_tokens,windowsize=2, temperature=1.0, action_all=None,**kwargs): self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True) input_ids = torch.cat([input_ids, action_all[0]], dim=-1) position_ids = torch.arange(0, input_ids.shape[1], device="cuda") next_token = self.prefill( self, input_ids=input_ids, position_ids=position_ids, temperature=temperature, ) self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True) # self.decode_some_token = decode_some_token position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda") generated_tokens = video_diagd_decode_n_tokens( self, input_ids = next_token.view(1, -1), position_ids = position_ids, num_generate_tokens = max_new_tokens - 1, temperature = temperature, decode_some_token_function=self.decode_some_token, windowsize = windowsize, action=action_all, prompt=input_ids, **kwargs ) return torch.cat(generated_tokens, dim=1) ================================================ FILE: mcdataset.py ================================================ #!/usr/bin/env python3 # -*- coding: utf-8 -*- import os import json import attr import collections import numpy as np from typing import Union, Dict import torch from utils import print0 # https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L17 KEYBOARD_BUTTON_MAPPING = { "key.keyboard.escape" :"ESC", "key.keyboard.s" :"back", "key.keyboard.q" :"drop", "key.keyboard.w" :"forward", "key.keyboard.1" :"hotbar.1", "key.keyboard.2" :"hotbar.2", "key.keyboard.3" :"hotbar.3", "key.keyboard.4" :"hotbar.4", "key.keyboard.5" :"hotbar.5", "key.keyboard.6" :"hotbar.6", "key.keyboard.7" :"hotbar.7", "key.keyboard.8" :"hotbar.8", "key.keyboard.9" :"hotbar.9", "key.keyboard.e" :"inventory", "key.keyboard.space" :"jump", "key.keyboard.a" :"left", "key.keyboard.d" :"right", "key.keyboard.left.shift" :"sneak", 
"key.keyboard.left.control" :"sprint", "key.keyboard.f" :"swapHands", } # https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L41 # Template action NOOP_ACTION = { "ESC": 0, "back": 0, "drop": 0, "forward": 0, "hotbar.1": 0, "hotbar.2": 0, "hotbar.3": 0, "hotbar.4": 0, "hotbar.5": 0, "hotbar.6": 0, "hotbar.7": 0, "hotbar.8": 0, "hotbar.9": 0, "inventory": 0, "jump": 0, "left": 0, "right": 0, "sneak": 0, "sprint": 0, "swapHands": 0, "camera": np.array([0, 0]), # [y, x] "attack": 0, "use": 0, "pickItem": 0, } OASIS_ACTION_KEYS = [ "inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4", "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward", "back", "left", "right", "cameraX", "cameraY", "jump", "sneak", "sprint", "swapHands", "attack", "use", "pickItem", "drop", ] # Matches a number in the MineRL Java code regarding sensitivity # This is for mapping from recorded sensitivity to the one used in the model CAMERA_SCALER = 360.0 / 2400.0 # https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L8 with some modifications class Buttons: # 14 in total without hotbar and camera ATTACK = "attack" BACK = "back" FORWARD = "forward" JUMP = "jump" LEFT = "left" RIGHT = "right" SNEAK = "sneak" SPRINT = "sprint" USE = "use" DROP = "drop" INVENTORY = "inventory" # added by Yang ESC = "ESC" SWAPHANDS = "swapHands" PICKITEM = "pickItem" ALL = [ USE, ATTACK, FORWARD, BACK, LEFT, RIGHT, JUMP, SNEAK, SPRINT, DROP, SWAPHANDS, PICKITEM, INVENTORY, ESC, ] + [f"hotbar.{i}" for i in range(1, 10)] class QuantizationScheme: LINEAR = "linear" MU_LAW = "mu_law" # https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L49 @attr.s(auto_attribs=True) class CameraQuantizer: """ A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components. Parameters: - camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize. - camera_maxval: The maximum value of the camera action. - quantization_scheme: The quantization scheme to use. Currently, two quantization schemes are supported: - Linear quantization (default): Camera actions are split uniformly into discrete bins - Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm) followed by the same quantization scheme used by the linear scheme. - mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of mu will result in a sharper transition near zero. Below are some reference values listed for choosing mu given a constant maxval and a desired max_precision value. 
maxval = 10 | max_precision = 0.5 | μ ≈ 2.93826 maxval = 10 | max_precision = 0.4 | μ ≈ 4.80939 maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887 maxval = 20 | max_precision = 0.5 | μ ≈ 2.7 maxval = 20 | max_precision = 0.4 | μ ≈ 4.39768 maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194 maxval = 40 | max_precision = 0.5 | μ ≈ 2.60780 maxval = 40 | max_precision = 0.4 | μ ≈ 4.21554 maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152 """ camera_maxval: int camera_binsize: int quantization_scheme: str = attr.ib( default=QuantizationScheme.LINEAR, validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]), ) mu: float = attr.ib(default=5) def discretize(self, xy): xy = np.clip(xy, -self.camera_maxval, self.camera_maxval) if self.quantization_scheme == QuantizationScheme.MU_LAW: xy = xy / self.camera_maxval v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu)) v_encode *= self.camera_maxval xy = v_encode # Quantize using linear scheme return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64) def undiscretize(self, xy): xy = xy * self.camera_binsize - self.camera_maxval if self.quantization_scheme == QuantizationScheme.MU_LAW: xy = xy / self.camera_maxval v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0) v_decode *= self.camera_maxval xy = v_decode return xy class MCDataset(torch.utils.data.Dataset): """ Dataset for Minecraft. """ def __init__(self, action_length: int = 11, # including bos and eos camera_binsize: int = 9, # 2 in vpt camera_maxval: int = 90, # 10 in vpt camera_mu: float = 11.4887, # 10 in vpt quantization_scheme: str = "mu_law", ): self.action_length = action_length self.camera_quantizer = CameraQuantizer( camera_binsize=camera_binsize, camera_maxval=camera_maxval, mu=camera_mu, quantization_scheme=quantization_scheme, ) def json_action_to_env_action(self, json_action): """ https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L80 Converts a json action into a MineRL action. Returns (minerl_action, is_null_action) """ # This might be slow... env_action = NOOP_ACTION.copy() # As a safeguard, make camera action again so we do not override anything env_action["camera"] = np.array([0, 0]) is_null_action = True keyboard_keys = json_action["keyboard"]["keys"] for key in keyboard_keys: # You can have keys that we do not use, so just skip them # NOTE in original training code, ESC was removed and replaced with # "inventory" action if GUI was open. # Not doing it here, as BASALT uses ESC to quit the game. 
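For reference, a minimal round-trip sketch of the mu-law camera quantizer configured above (MCDataset defaults: `camera_maxval=90`, `camera_binsize=9`, `mu=11.4887`); the sample values are illustrative only:

```python
import numpy as np

# 2 * 90 / 9 + 1 = 21 bins per camera axis; bin 10 is the "no motion" bin,
# matching the default num_cam_bins=21 used when building the action vocabulary below.
quantizer = CameraQuantizer(
    camera_maxval=90,
    camera_binsize=9,
    mu=11.4887,
    quantization_scheme="mu_law",
)

pitch_yaw = np.array([0.0, 15.0])         # camera motion in degrees [y, x]
bins = quantizer.discretize(pitch_yaw)    # -> array([10, 14]) with these settings
recovered = quantizer.undiscretize(bins)  # -> approximately [0.0, 13.7]; quantization is lossy
```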
if key in KEYBOARD_BUTTON_MAPPING: env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1 is_null_action = False mouse = json_action["mouse"] camera_action = env_action["camera"] camera_action[0] = mouse["dy"] * CAMERA_SCALER camera_action[1] = mouse["dx"] * CAMERA_SCALER if mouse["dx"] != 0 or mouse["dy"] != 0: is_null_action = False else: if abs(camera_action[0]) > 180: camera_action[0] = 0 if abs(camera_action[1]) > 180: camera_action[1] = 0 mouse_buttons = mouse["buttons"] if 0 in mouse_buttons: env_action["attack"] = 1 is_null_action = False if 1 in mouse_buttons: env_action["use"] = 1 is_null_action = False if 2 in mouse_buttons: env_action["pickItem"] = 1 is_null_action = False # added by Yang # if two confictory actions are pressed, remove them if env_action["forward"] == 1 and env_action["back"] == 1: env_action["forward"] = 0 env_action["back"] = 0 if env_action["left"] == 1 and env_action["right"] == 1: env_action["left"] = 0 env_action["right"] = 0 if env_action["jump"] == 1 and env_action["sneak"] == 1: env_action["jump"] = 0 env_action["sneak"] = 0 if env_action["sprint"] == 1 and env_action["sneak"] == 1: env_action["sprint"] = 0 env_action["sneak"] = 0 if env_action["attack"] == 1 and env_action["use"] == 1: env_action["attack"] = 0 env_action["use"] = 0 # remove inventory and ESC action if env_action["inventory"] == 1: is_null_action = True if env_action["ESC"] == 1: is_null_action = True return env_action, is_null_action def make_action_vocab(self, num_cam_bins: int = 21, action_vocab_offset: int = 0, verbose: bool = False): action_vocab = collections.OrderedDict() # 14 actions and hotbar.1-9 for i, action in enumerate(Buttons.ALL): action_vocab[action] = i # camera 0 for i in range(num_cam_bins): action_vocab[f"cam_0_{i}"] = len(Buttons.ALL) + i # camera 1 for i in range(num_cam_bins): action_vocab[f"cam_1_{i}"] = len(Buttons.ALL) + num_cam_bins + i # bos, null, eos action_vocab[""] = len(Buttons.ALL) + 2 * num_cam_bins action_vocab[""] = len(Buttons.ALL) + 2 * num_cam_bins + 1 action_vocab[""] = len(Buttons.ALL) + 2 * num_cam_bins + 2 if action_vocab_offset > 0: action_vocab = {k: v + action_vocab_offset for k, v in action_vocab.items()} if verbose: print0(f"[bold yellow]\[MCDataset][/bold yellow] Action Vocab: {action_vocab}") self.action_vocab = action_vocab # return action_vocab def _handle_conflict_action_index(self, action_dict: Dict[str, Union[int, np.ndarray]], key1: str, key2: str, null_key: str, verbose: bool = False): if action_dict[key1] == 1 and action_dict[key2] == 1: if verbose: print0(f"[bold yellow]\[MCDataset][/bold yellow] {key1} and {key2} are both pressed") return self.action_vocab[null_key] elif action_dict[key1] == 1: return self.action_vocab[key1] elif action_dict[key2] == 1: return self.action_vocab[key2] else: return self.action_vocab[null_key] def get_action_index_from_actiondict(self, action_dict: Dict[str, Union[int, np.ndarray]], action_vocab_offset: int = 0, verbose: bool = False): if not hasattr(self, "action_vocab"): self.make_action_vocab(action_vocab_offset=action_vocab_offset, verbose=verbose) # action_list = [boa, camy, camx, hotbar, fore_back, left_right, sprint_sneak, use_attack, jump, drop_pick, eoa] # 11 actions action_list = [self.action_vocab[""]] * self.action_length # 0 & 10 action_list[0] = self.action_vocab[""] action_list[-1] = self.action_vocab[""] camera_action = action_dict["camera"] assert len(camera_action) == 2, f"[MCDataset] camera_action length is not 2: {camera_action}" # camera_action should be numpy array if not 
isinstance(camera_action, np.ndarray): camera_action = np.array(camera_action) camera_action = self.camera_quantizer.discretize(camera_action) # 1 & 2 action_list[1] = self.action_vocab[f"cam_0_{camera_action[0]}"] action_list[2] = self.action_vocab[f"cam_1_{camera_action[1]}"] # 3 for i in range(1, 10): if f"hotbar.{i}" in action_dict and action_dict[f"hotbar.{i}"] == 1: action_list[3] = self.action_vocab[f"hotbar.{i}"] break # 4 forward/back action_list[4] = self._handle_conflict_action_index(action_dict, "forward", "back", "", verbose=verbose) # 5 left/right action_list[5] = self._handle_conflict_action_index(action_dict, "left", "right", "", verbose=verbose) # 6 sprint/sneak action_list[6] = self._handle_conflict_action_index(action_dict, "sprint", "sneak", "", verbose=verbose) # 7 use/attack action_list[7] = self._handle_conflict_action_index(action_dict, "use", "attack", "", verbose=verbose) # 8 jump action_list[8] = self.action_vocab["jump"] if action_dict["jump"] == 1 else self.action_vocab[""] # 9 drop/pick action_list[9] = self._handle_conflict_action_index(action_dict, "drop", "pickItem", "", verbose=verbose) if verbose: print0(f"[bold yellow]\[MCDataset][/bold yellow] Action List: {action_list}") return action_list def read_jsonl(self, jsonl_path: str): assert os.path.isfile(jsonl_path), f"[MCDataset] {jsonl_path} does not exist" # read jsonl # https://github.com/openai/Video-Pre-Training/blob/main/data_loader.py#L76 try: with open(jsonl_path) as json_file: json_lines = json_file.readlines() json_data = "[" + ",".join(json_lines) + "]" json_data = json.loads(json_data) except Exception as e: print0(f"[bold yellow]\[MCDataset][/bold yellow] {jsonl_path} cannot be read: {e}") return None return json_data ================================================ FILE: metrics/IDM/inverse_dynamics_model.py ================================================ # Borrowed from VPT (https://github.com/openai/Video-Pre-Training) import numpy as np import torch as th import cv2 from gym3.types import DictType from gym import spaces from tqdm import tqdm import os from argparse import ArgumentParser import pickle import cv2 import json from lib.action_mapping import IDMActionMapping from lib.actions import ActionTransformer from lib.policy import InverseActionPolicy from lib.torch_util import default_device_type, set_default_torch_device from sklearn.metrics import precision_score, recall_score, f1_score # Hardcoded settings AGENT_RESOLUTION = (128, 128) safe_globals = {"array": np.array} def resize_image(img, target_resolution): # For your sanity, do not resize with any function than INTER_LINEAR img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR) return img ACTION_TRANSFORMER_KWARGS = dict( camera_binsize=2, camera_maxval=10, camera_mu=10, camera_quantization_scheme="mu_law", ) class IDMAgent: """ Sugarcoating on the inverse dynamics model (IDM) used to predict actions Minecraft players take in videos. Functionally same as MineRLAgent. 
""" def __init__(self, idm_net_kwargs, pi_head_kwargs, device=None): if device is None: device = default_device_type() self.device = th.device(device) # Set the default torch device for underlying code as well set_default_torch_device(self.device) self.action_mapper = IDMActionMapping(n_camera_bins=11) action_space = self.action_mapper.get_action_space_update() action_space = DictType(**action_space) self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS) idm_policy_kwargs = dict(idm_net_kwargs=idm_net_kwargs, pi_head_kwargs=pi_head_kwargs, action_space=action_space) self.policy = InverseActionPolicy(**idm_policy_kwargs).to(device) self.hidden_state = self.policy.initial_state(1) self._dummy_first = th.from_numpy(np.array((False,))).to(device) def load_weights(self, path): """Load model weights from a path, and reset hidden state""" self.policy.load_state_dict(th.load(path, map_location=self.device), strict=False) self.reset() def reset(self): """Reset agent to initial state (i.e., reset hidden state)""" self.hidden_state = self.policy.initial_state(1) def _video_obs_to_agent(self, video_frames): imgs = [resize_image(frame, AGENT_RESOLUTION) for frame in video_frames] # Add time and batch dim imgs = np.stack(imgs)[None] agent_input = {"img": th.from_numpy(imgs).to(self.device)} return agent_input def _agent_action_to_env(self, agent_action): """Turn output from policy into action for MineRL""" # This is quite important step (for some reason). # For the sake of your sanity, remember to do this step (manual conversion to numpy) # before proceeding. Otherwise, your agent might be a little derp. action = { "buttons": agent_action["buttons"].cpu().numpy(), "camera": agent_action["camera"].cpu().numpy() } minerl_action = self.action_mapper.to_factored(action) minerl_action_transformed = self.action_transformer.policy2env(minerl_action) return minerl_action_transformed def predict_actions(self, video_frames): """ Predict actions for a sequence of frames. `video_frames` should be of shape (N, H, W, C). Returns MineRL action dict, where each action head has shape (N, ...). Agent's hidden state is tracked internally. To reset it, call `reset()`. """ agent_input = self._video_obs_to_agent(video_frames) # The "first" argument could be used to reset tell episode # boundaries, but we are only using this for predicting (for now), # so we do not hassle with it yet. dummy_first = th.zeros((video_frames.shape[0], 1)).to(self.device) predicted_actions, self.hidden_state, _ = self.policy.predict( agent_input, first=dummy_first, state_in=self.hidden_state, deterministic=True ) predicted_minerl_action = self._agent_action_to_env(predicted_actions) return predicted_minerl_action # NOTE: this is _not_ the original code of IDM! # As such, while it is close and seems to function well, # its performance might be bit off from what is reported # in the paper. 
ENV_KWARGS = dict( fov_range=[70, 70], frameskip=1, gamma_range=[2, 2], guiscale_range=[1, 1], resolution=[640, 360], cursor_size_range=[16.0, 16.0], ) KEYBOARD_BUTTON_MAPPING = { "key.keyboard.escape" :"ESC", "key.keyboard.s" :"back", "key.keyboard.q" :"drop", "key.keyboard.w" :"forward", "key.keyboard.1" :"hotbar.1", "key.keyboard.2" :"hotbar.2", "key.keyboard.3" :"hotbar.3", "key.keyboard.4" :"hotbar.4", "key.keyboard.5" :"hotbar.5", "key.keyboard.6" :"hotbar.6", "key.keyboard.7" :"hotbar.7", "key.keyboard.8" :"hotbar.8", "key.keyboard.9" :"hotbar.9", "key.keyboard.e" :"inventory", "key.keyboard.space" :"jump", "key.keyboard.a" :"left", "key.keyboard.d" :"right", "key.keyboard.left.shift" :"sneak", "key.keyboard.left.control" :"sprint", "key.keyboard.f" :"swapHands", } # Template action NOOP_ACTION = { "ESC": 0, "back": 0, "drop": 0, "forward": 0, "hotbar.1": 0, "hotbar.2": 0, "hotbar.3": 0, "hotbar.4": 0, "hotbar.5": 0, "hotbar.6": 0, "hotbar.7": 0, "hotbar.8": 0, "hotbar.9": 0, "inventory": 0, "jump": 0, "left": 0, "right": 0, "sneak": 0, "sprint": 0, "swapHands": 0, "camera": np.array([0, 0]), "attack": 0, "use": 0, "pickItem": 0, } # Matches a number in the MineRL Java code regarding sensitivity # This is for mapping from recorded sensitivity to the one used in the model CAMERA_SCALER = 360.0 / 2400.0 def json_action_to_env_action(json_action): """ Converts a json action into a MineRL action. Returns (minerl_action, is_null_action) """ if "ESC" in json_action: return json_action, False # This might be slow... env_action = NOOP_ACTION.copy() # As a safeguard, make camera action again so we do not override anything env_action["camera"] = np.array([0, 0]) is_null_action = True keyboard_keys = json_action["keyboard"]["keys"] for key in keyboard_keys: # You can have keys that we do not use, so just skip them # NOTE in original training code, ESC was removed and replaced with # "inventory" action if GUI was open. # Not doing it here, as BASALT uses ESC to quit the game. if key in KEYBOARD_BUTTON_MAPPING: env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1 is_null_action = False mouse = json_action["mouse"] camera_action = env_action["camera"] camera_action[0] = mouse["dy"] * CAMERA_SCALER camera_action[1] = mouse["dx"] * CAMERA_SCALER if mouse["dx"] != 0 or mouse["dy"] != 0: is_null_action = False else: if abs(camera_action[0]) > 180: camera_action[0] = 0 if abs(camera_action[1]) > 180: camera_action[1] = 0 mouse_buttons = mouse["buttons"] if 0 in mouse_buttons: env_action["attack"] = 1 is_null_action = False if 1 in mouse_buttons: env_action["use"] = 1 is_null_action = False if 2 in mouse_buttons: env_action["pickItem"] = 1 is_null_action = False return env_action, is_null_action def load_action_jsonl(json_path): with open(json_path) as json_file: json_lines = json_file.readlines() json_data = "[" + ",".join(json_lines) + "]" json_data = json.loads(json_data) return json_data # loss on frame - avg on video - avg on dataset def evaluate_IDM_quality(model, weights,jsonl_folder, video_folder, infer_demo_num, n_frames, output_file): """ Evaluate the quality of a IDM model on a dataset of videos. Args: video_folder (str): Path to the folder containing videos. model (str): Path to the '.model' file to be loaded. weights (str): Path to the '.weights' file to be loaded. n_batches (int): Number of batches to process. n_frames (int): Number of frames to process at a time. 
""" ## set up IDM model agent_parameters = pickle.load(open(model, "rb")) net_kwargs = agent_parameters["model"]["args"]["net"]["args"] pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"] pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"]) # pi_head_kwargs["temperature"] = 1.0 agent = IDMAgent(idm_net_kwargs=net_kwargs, pi_head_kwargs=pi_head_kwargs) agent.load_weights(weights) # Load video files video_files = os.listdir(video_folder) video_files = [f for f in video_files if f.endswith(".mp4")] video_files = sorted(video_files) video_files = [os.path.join(video_folder, f) for f in video_files] eval_num = min(500,len(video_files)) video_files = video_files[:eval_num] dataset_labels = {} camera_loss_list = [] for video_file in tqdm(video_files): json_file = os.path.join(jsonl_folder,os.path.basename(video_file).replace(".mp4",".jsonl")) # old implementation # action_loss,video_avg_loss,predicted_actions_list = eval_1_video(agent, video_file, json_file, infer_demo_num, n_frames) # load predicted actions and recorded actions predicted_actions,recorded_actions = idm_prediction(agent, video_file,json_file, infer_demo_num, n_frames) # construct labels subtasks_labels = define_exclusive_classification_task(predicted_actions,recorded_actions,calculate_hot_bar = False) for key in subtasks_labels: if key not in dataset_labels: dataset_labels[key] = {"pred_labels":[] , "rec_labels":[], "class_num":0} dataset_labels[key]["pred_labels"].append(subtasks_labels[key]["pred_labels"])# array dataset_labels[key]["rec_labels"].append(subtasks_labels[key]["rec_labels"]) # array dataset_labels[key]["class_num"] = subtasks_labels[key]["class_num"] camera_loss_list.append(camera_loss(predicted_actions,recorded_actions)["camera_bin_loss"]) dataset_results ={} for key in dataset_labels: pred_labels = np.stack(dataset_labels[key]["pred_labels"]).flatten() # [num_videos , num_frames] -> [video_num x frame_num] rec_labels = np.stack(dataset_labels[key]["rec_labels"]).flatten() # [num_videos , num_frames] -> [video_num x frame_num] dataset_results[key]=classification_metric(pred_labels, rec_labels, dataset_labels[key]["class_num"]) # import pdb;pdb.set_trace() metric_mean_on_task = {} metrics = ['precision_micro', 'recall_micro', 'f1_micro', 'precision_macro', 'recall_macro', 'f1_macro'] tasks = dataset_results.keys() for key in metrics: if key == "class_num": continue metric_mean_on_task[key] = np.mean([dataset_results[task][key] for task in tasks]) dataset_results["metric_mean_on_task"] = metric_mean_on_task dataset_results["metric_mean_on_task"]["camera_loss"] = np.mean(camera_loss_list) ## change all keys into str dataset_results = {str(k): v for k, v in dataset_results.items()} print(dataset_results) print("===========================================") print(f"{output_file} IDM Metric: {metric_mean_on_task}") with open(output_file, 'w') as f: f.write(json.dumps(dataset_results,indent=4) + "\n") def construct_classification_labels(idm_actions:dict[str, list[int]],action_name_keys: list[int],num_class:int) -> list[int]: """ convert original predicted actions to classification labels """ # construct a one-hot vector string to int label vec2cls = {"0"*(num_class-1):0} for i in range(num_class-1): key = "0"*i + "1" + "0"*(num_class-2-i) vec2cls[key] = i+1 # print(vec2cls) vec2cls['1'*(num_class-1)] = 0 # do all equal not do # vec2cls = {"00":0,"10":1,"01":2} # tested for class_num = 2 num_labels = idm_actions[action_name_keys[0]].size # assert same length: video_num x frame_per_video # if 
not single in first dim, we should perform flattn # construct one-hot vector idm_action_string = [[str(int(i)) for i in idm_actions[action_name].flatten()] for action_name in action_name_keys] try: labels = [vec2cls["".join([idm_action_string[j][i] for j in range(num_class-1)])] for i in range(num_labels)] except: conflicts_num = sum([ i=='1' and j=='1' for i,j in zip(idm_action_string[0],idm_action_string[1])]) print(f"detect conflict prediction: {conflicts_num}") return None labels = np.array(labels) return labels def define_exclusive_classification_task(predicted_actions:dict,recorded_actions:dict,calculate_hot_bar = False) -> dict: subtasks = {"multi_class":[("back","forward"),# 01,00,10, ("left","right"), ("sneak","sprint"), ], "binary_class":["use","attack","jump","drop"] } if calculate_hot_bar: subtasks["multi_class"]=[("hotbar.1","hotbar.2","hotbar.3","hotbar.4","hotbar.5","hotbar.6","hotbar.7","hotbar.8","hotbar.9")] subtasks_labels = {} for class_pair in subtasks["multi_class"]: class_num = len(class_pair) + 1 # len = 2 has 00 01 10 # convert to strings # convert to classification pred_labels = construct_classification_labels(predicted_actions, class_pair, class_num) rec_labels = construct_classification_labels(recorded_actions, class_pair, class_num) if pred_labels is None or rec_labels is None: print(f"detect conflict prediction: {pred_labels} and {rec_labels}") continue subtasks_labels[class_pair] = {"class_num":class_num, "pred_labels":pred_labels, "rec_labels":rec_labels } for binary_task in subtasks["binary_class"]: pred_labels = predicted_actions[binary_task] rec_labels = recorded_actions[binary_task] subtasks_labels[binary_task] = {"class_num":2, "pred_labels":pred_labels, "rec_labels":rec_labels } return subtasks_labels def classification_metric(pred_labels, rec_labels, class_num): ## compute macro and micro score for both tri classification and binary classification ## the difference between macro and micro precision and binary precision for binary task is : ## the binary precision only compute label with 1 ; but micro and marco compute 0, 1 and then average them ## to align with tri-classification we use average="macro" and average="micro" precision_micro = precision_score(rec_labels, pred_labels, average="micro", zero_division=0) recall_micro = recall_score(rec_labels, pred_labels, average="micro", zero_division=0) f1_micro = f1_score(rec_labels, pred_labels, average="micro", zero_division=0) precision_macro = precision_score(rec_labels, pred_labels, average="macro", zero_division=0) recall_macro = recall_score(rec_labels, pred_labels, average="macro", zero_division=0) f1_macro = f1_score(rec_labels, pred_labels, average="macro", zero_division=0) return { "precision_micro": precision_micro, "recall_micro": recall_micro, "f1_micro": f1_micro, "precision_macro": precision_macro, "recall_macro": recall_macro, "f1_macro": f1_macro, "class_num": class_num } def aggregate_actions(actions:list) -> dict: return_dict = {} for action in actions: for key in action: if key not in return_dict: return_dict[key] = [] return_dict[key].append(action[key]) for key in return_dict: return_dict[key] = np.array(return_dict[key]).reshape(-1) return return_dict def idm_prediction(agent, video_path,json_path, infer_demo_num, n_frames): th.cuda.empty_cache() full_json_data = load_action_jsonl(json_path) json_data = full_json_data[infer_demo_num:infer_demo_num+n_frames] recorded_actions = [json_action_to_env_action(i)[0] for i in json_data] recorded_actions = aggregate_actions(recorded_actions) 
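To make the micro/macro distinction described in `classification_metric` above concrete, a small illustrative example with toy labels (class 0 meaning "none pressed", as produced by `construct_classification_labels`):

```python
from sklearn.metrics import f1_score

# Toy per-frame labels for a 3-way task such as ("back", "forward"):
# 0 = neither pressed, 1 = back, 2 = forward.
rec_labels  = [0, 0, 1, 2, 2, 2]   # ground truth from the recorded jsonl actions
pred_labels = [0, 1, 1, 2, 2, 0]   # IDM predictions

# Micro averaging pools all frames before scoring (equivalent to accuracy here);
# macro averaging scores each class separately and takes the unweighted mean,
# so minority classes weigh as much as the majority class.
f1_micro = f1_score(rec_labels, pred_labels, average="micro", zero_division=0)
f1_macro = f1_score(rec_labels, pred_labels, average="macro", zero_division=0)
```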
frames = [] cap = cv2.VideoCapture(video_path) for _ in range(n_frames): ret, frame = cap.read() if not ret: print(f"[Error] loading frames in {video_path} returing {_}") return None,None # BGR -> RGB frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(frame) frames = np.stack(frames) predicted_actions = agent.predict_actions(frames) for key in predicted_actions: if key == "camera": continue predicted_actions[key] = np.array(predicted_actions[key]).reshape(-1) return predicted_actions,recorded_actions def camera_loss(predicted_actions,recorded_actions): from lib.actions import CameraQuantizer cam_quantizer = CameraQuantizer( camera_binsize=2, camera_maxval=10, mu=10, quantization_scheme="mu_law") # import pdb;pdb.set_trace() cam_pred_token=cam_quantizer.discretize(predicted_actions['camera'].reshape(-1)) cam_gt_token =cam_quantizer.discretize(np.array(recorded_actions['camera'])) camera_bin_loss = np.abs(cam_pred_token-cam_gt_token).mean() return { "camera_bin_loss":camera_bin_loss } if __name__ == "__main__": parser = ArgumentParser("Evaluate IDM quality for MC-LVM ") parser.add_argument("--weights", type=str, required=True, help="[IDM model config] Path to the '.weights' file to be loaded.") parser.add_argument("--model", type=str, required=True, help="[IDM model config] Path to the '.model' file to be loaded.") parser.add_argument("--jsonl-path", type=str, required=True, help="[Eval Config] Path to .jsonl contains actions.") parser.add_argument("--video-path", type=str, required=True, help="[Eval Config] Path to a .mp4 file.") parser.add_argument("--infer-demo-num", type=int, default=0, help="[Inference Config] Number of frames to skip before starting evaluation.") parser.add_argument("--n-frames", type=int, default=32, help="[Inference Config] Number of frames to generation.") parser.add_argument("--output-file", type=str, default="[Eval Config] output/action_loss.jsonl", help="[Eval Config] Path to save the action loss.") args = parser.parse_args() os.makedirs(os.path.dirname(args.output_file), exist_ok=True) evaluate_IDM_quality(args.model, args.weights,args.jsonl_path ,args.video_path, args.infer_demo_num,args.n_frames,args.output_file) ================================================ FILE: metrics/IDM/lib/__init__.py ================================================ ================================================ FILE: metrics/IDM/lib/action_head.py ================================================ import logging from typing import Any, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init from gym3.types import DictType, Discrete, Real, TensorType, ValType LOG0 = -100 def fan_in_linear(module: nn.Module, scale=1.0, bias=True): """Fan-in init""" module.weight.data *= scale / module.weight.norm(dim=1, p=2, keepdim=True) if bias: module.bias.data *= 0 class ActionHead(nn.Module): """Abstract base class for action heads compatible with forc""" def forward(self, input_data: torch.Tensor) -> Any: """ Just a forward pass through this head :returns pd_params - parameters describing the probability distribution """ raise NotImplementedError def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor: """Logartithm of probability of sampling `action_sample` from a probability described by `pd_params`""" raise NotImplementedError def entropy(self, pd_params: torch.Tensor) -> torch.Tensor: """Entropy of this distribution""" raise NotImplementedError def sample(self, pd_params: torch.Tensor, 
deterministic: bool = False) -> Any: """ Draw a sample from probability distribution given by those params :param pd_params Parameters of a probability distribution :param deterministic Whether to return a stochastic sample or deterministic mode of a distribution """ raise NotImplementedError def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor: """KL divergence between two distribution described by these two params""" raise NotImplementedError class DiagGaussianActionHead(ActionHead): """ Action head where actions are normally distributed uncorrelated variables with specific means and variances. Means are calculated directly from the network while standard deviations are a parameter of this module """ LOG2PI = np.log(2.0 * np.pi) def __init__(self, input_dim: int, num_dimensions: int): super().__init__() self.input_dim = input_dim self.num_dimensions = num_dimensions self.linear_layer = nn.Linear(input_dim, num_dimensions) self.log_std = nn.Parameter(torch.zeros(num_dimensions), requires_grad=True) def reset_parameters(self): init.orthogonal_(self.linear_layer.weight, gain=0.01) init.constant_(self.linear_layer.bias, 0.0) def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor: assert not mask, "Can not use a mask in a gaussian action head" means = self.linear_layer(input_data) # Unsqueeze many times to get to the same shape logstd = self.log_std[(None,) * (len(means.shape) - 1)] mean_view, logstd = torch.broadcast_tensors(means, logstd) return torch.stack([mean_view, logstd], dim=-1) def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor: """Log-likelihood""" means = pd_params[..., 0] log_std = pd_params[..., 1] std = torch.exp(log_std) z_score = (action_sample - means) / std return -(0.5 * ((z_score ** 2 + self.LOG2PI).sum(dim=-1)) + log_std.sum(dim=-1)) def entropy(self, pd_params: torch.Tensor) -> torch.Tensor: """ Categorical distribution entropy calculation - sum probs * log(probs). 
In case of diagonal gaussian distribution - 1/2 log(2 pi e sigma^2) """ log_std = pd_params[..., 1] return (log_std + 0.5 * (self.LOG2PI + 1)).sum(dim=-1) def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor: means = pd_params[..., 0] log_std = pd_params[..., 1] if deterministic: return means else: return torch.randn_like(means) * torch.exp(log_std) + means def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor: """ Categorical distribution KL divergence calculation KL(Q || P) = sum Q_i log (Q_i / P_i) Formula is: log(sigma_p) - log(sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2))/(2 * sigma_p^2) """ means_q = params_q[..., 0] log_std_q = params_q[..., 1] means_p = params_p[..., 0] log_std_p = params_p[..., 1] std_q = torch.exp(log_std_q) std_p = torch.exp(log_std_p) kl_div = log_std_p - log_std_q + (std_q ** 2 + (means_q - means_p) ** 2) / (2.0 * std_p ** 2) - 0.5 return kl_div.sum(dim=-1, keepdim=True) class CategoricalActionHead(ActionHead): """Action head with categorical actions""" def __init__( self, input_dim: int, shape: Tuple[int], num_actions: int, builtin_linear_layer: bool = True, temperature: float = 1.0 ): super().__init__() self.input_dim = input_dim self.num_actions = num_actions self.output_shape = shape + (num_actions,) self.temperature = temperature if builtin_linear_layer: self.linear_layer = nn.Linear(input_dim, np.prod(self.output_shape)) else: assert ( input_dim == num_actions ), f"If input_dim ({input_dim}) != num_actions ({num_actions}), you need a linear layer to convert them." self.linear_layer = None def reset_parameters(self): if self.linear_layer is not None: init.orthogonal_(self.linear_layer.weight, gain=0.01) init.constant_(self.linear_layer.bias, 0.0) finit.fan_in_linear(self.linear_layer, scale=0.01) def forward(self, input_data: torch.Tensor, mask=None) -> Any: if self.linear_layer is not None: flat_out = self.linear_layer(input_data) else: flat_out = input_data shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape) shaped_out /= self.temperature if mask is not None: shaped_out[~mask] = LOG0 # Convert to float32 to avoid RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half' return F.log_softmax(shaped_out.float(), dim=-1) def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: value = actions.long().unsqueeze(-1) value, log_pmf = torch.broadcast_tensors(value, logits) value = value[..., :1] result = log_pmf.gather(-1, value).squeeze(-1) # result is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it. for _ in self.output_shape[:-1]: result = result.sum(dim=-1) return result def entropy(self, logits: torch.Tensor) -> torch.Tensor: """Categorical distribution entropy calculation - sum probs * log(probs)""" probs = torch.exp(logits) entropy = -torch.sum(probs * logits, dim=-1) # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it. for _ in self.output_shape[:-1]: entropy = entropy.sum(dim=-1) return entropy def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any: if deterministic: return torch.argmax(logits, dim=-1) else: # Gumbel-Softmax trick. u = torch.rand_like(logits) # In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0 # Then the log(-log(1.0)) will give -inf when it should give +inf # This is a silly hack to get around that. 
# This hack does not skew the probability distribution, because this event can't possibly win the argmax. u[u == 1.0] = 0.999 return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1) def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor: """ Categorical distribution KL divergence calculation KL(Q || P) = sum Q_i log (Q_i / P_i) When talking about logits this is: sum exp(Q_i) * (Q_i - P_i) """ kl = (torch.exp(logits_q) * (logits_q - logits_p)).sum(-1, keepdim=True) # kl is per-entry, still of size self.output_shape; we need to reduce of the rest of it. for _ in self.output_shape[:-1]: kl = kl.sum(dim=-2) # dim=-2 because we use keepdim=True above. return kl class DictActionHead(nn.ModuleDict): """Action head with multiple sub-actions""" def reset_parameters(self): for subhead in self.values(): subhead.reset_parameters() def forward(self, input_data: torch.Tensor, **kwargs) -> Any: """ :param kwargs: each kwarg should be a dict with keys corresponding to self.keys() e.g. if this ModuleDict has submodules keyed by 'A', 'B', and 'C', we could call: forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7}} Then children will be called with: A: forward(input_data, foo=True, bar=7) B: forward(input_data) C: forward(input_Data, foo=False) """ result = {} for head_name, subhead in self.items(): head_kwargs = { kwarg_name: kwarg[head_name] for kwarg_name, kwarg in kwargs.items() if kwarg is not None and head_name in kwarg } result[head_name] = subhead(input_data, **head_kwargs) return result def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: return sum(subhead.logprob(actions[k], logits[k]) for k, subhead in self.items()) def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any: return {k: subhead.sample(logits[k], deterministic) for k, subhead in self.items()} def entropy(self, logits: torch.Tensor) -> torch.Tensor: return sum(subhead.entropy(logits[k]) for k, subhead in self.items()) def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor: return sum(subhead.kl_divergence(logits_q[k], logits_p[k]) for k, subhead in self.items()) def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0): """Helper function to create an action head corresponding to the environment action space""" if isinstance(ac_space, TensorType): if isinstance(ac_space.eltype, Discrete): return CategoricalActionHead(pi_out_size, ac_space.shape, ac_space.eltype.n, temperature=temperature) elif isinstance(ac_space.eltype, Real): if temperature != 1.0: logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.") assert len(ac_space.shape) == 1, "Nontrivial shapes not yet implemented." return DiagGaussianActionHead(pi_out_size, ac_space.shape[0]) elif isinstance(ac_space, DictType): return DictActionHead({k: make_action_head(v, pi_out_size, temperature) for k, v in ac_space.items()}) raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported") ================================================ FILE: metrics/IDM/lib/action_mapping.py ================================================ import abc import itertools from collections import OrderedDict from typing import Dict, List import numpy as np from gym3.types import DictType, Discrete, TensorType from lib.actions import Buttons class ActionMapping(abc.ABC): """Class that maps between the standard MC factored action space and a new one you define! 
:param n_camera_bins: Need to specify this to define the original ac space for stats code """ # This is the default buttons groups, it can be changed for your action space BUTTONS_GROUPS = OrderedDict( hotbar=["none"] + [f"hotbar.{i}" for i in range(1, 10)], fore_back=["none", "forward", "back"], left_right=["none", "left", "right"], sprint_sneak=["none", "sprint", "sneak"], use=["none", "use"], drop=["none", "drop"], attack=["none", "attack"], jump=["none", "jump"], ) def __init__(self, n_camera_bins: int = 11): assert n_camera_bins % 2 == 1, "n_camera_bins should be odd" self.n_camera_bins = n_camera_bins self.camera_null_bin = n_camera_bins // 2 self.stats_ac_space = DictType( **{ "buttons": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)), "camera": TensorType(shape=(2,), eltype=Discrete(n_camera_bins)), } ) @abc.abstractmethod def from_factored(self, ac: Dict) -> Dict: """Converts a factored action (ac) to the new space :param ac: Dictionary of actions that must have a batch dimension """ pass @abc.abstractmethod def to_factored(self, ac: Dict) -> Dict: """Converts an action in the new space (ac) to the factored action space. :param ac: Dictionary of actions that must have a batch dimension """ pass @abc.abstractmethod def get_action_space_update(self): """Return a magym (gym3) action space. This will be used to update the env action space.""" pass @abc.abstractmethod def get_zero_action(self): """Return the zero or null action for this action space""" pass def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_group: List[str]) -> List[str]: """For a mutually exclusive group of buttons in button_group, find which option in the group was chosen. Assumes that each button group has the option of 'none' meaning that no button in the group was pressed. :param ac_buttons: button actions from the factored action space. Should dims [B, len(Buttons.ALL)] :param button_group: List of buttons in a mutually exclusive group. Each item in the list should appear in Buttons.ALL except for the special case 'none' which means no button in the group was pressed. e.g. ['none', 'forward', 'back']. For now 'none' must be the first element of button_group Returns a list of length B, where each element is an item from button_group. """ assert ac_buttons.shape[1] == len( Buttons.ALL ), f"There should be {len(Buttons.ALL)} buttons in the factored buttons space" assert button_group[0] == "none", "This function only works if 'none' is in button_group" # Actions in ac_buttons with order according to button_group group_indices = [Buttons.ALL.index(b) for b in button_group if b != "none"] ac_choices = ac_buttons[:, group_indices] # Special cases for forward/back, left/right where mutual press means do neither if "forward" in button_group and "back" in button_group: ac_choices[np.all(ac_choices, axis=-1)] = 0 if "left" in button_group and "right" in button_group: ac_choices[np.all(ac_choices, axis=-1)] = 0 ac_non_zero = np.where(ac_choices) ac_choice = ["none" for _ in range(ac_buttons.shape[0])] # Iterate over the non-zero indices so that if two buttons in a group were pressed at the same time # we give priority to the button later in the group. E.g. 
if hotbar.1 and hotbar.2 are pressed during the same # timestep, hotbar.2 is marked as pressed for index, action in zip(ac_non_zero[0], ac_non_zero[1]): ac_choice[index] = button_group[action + 1] # the zero'th index will mean no button pressed return ac_choice class IDMActionMapping(ActionMapping): """For IDM, but essentially this is just an identity mapping""" def from_factored(self, ac: Dict) -> Dict: return ac def to_factored(self, ac: Dict) -> Dict: return ac def get_action_space_update(self): """Return a magym (gym3) action space. This will be used to update the env action space.""" return { "buttons": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)), "camera": TensorType(shape=(2,), eltype=Discrete(self.n_camera_bins)), } def get_zero_action(self): raise NotImplementedError() class CameraHierarchicalMapping(ActionMapping): """Buttons are joint as in ButtonsJointMapping, but now a camera on/off meta action is added into this joint space. When this meta action is triggered, the separate camera head chooses a camera action which is also now a joint space. :param n_camera_bins: number of camera bins in the factored space """ # Add camera meta action to BUTTONS_GROUPS BUTTONS_GROUPS = ActionMapping.BUTTONS_GROUPS.copy() BUTTONS_GROUPS["camera"] = ["none", "camera"] BUTTONS_COMBINATIONS = list(itertools.product(*BUTTONS_GROUPS.values())) + ["inventory"] BUTTONS_COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(BUTTONS_COMBINATIONS)} BUTTONS_IDX_TO_COMBINATION = {i: comb for i, comb in enumerate(BUTTONS_COMBINATIONS)} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.camera_groups = OrderedDict( camera_x=[f"camera_x{i}" for i in range(self.n_camera_bins)], camera_y=[f"camera_y{i}" for i in range(self.n_camera_bins)], ) self.camera_combinations = list(itertools.product(*self.camera_groups.values())) self.camera_combination_to_idx = {comb: i for i, comb in enumerate(self.camera_combinations)} self.camera_idx_to_combination = {i: comb for i, comb in enumerate(self.camera_combinations)} self.camera_null_idx = self.camera_combination_to_idx[ (f"camera_x{self.camera_null_bin}", f"camera_y{self.camera_null_bin}") ] self._null_action = { "buttons": self.BUTTONS_COMBINATION_TO_IDX[tuple("none" for _ in range(len(self.BUTTONS_GROUPS)))] } self._precompute_to_factored() def _precompute_to_factored(self): """Precompute the joint action -> factored action matrix.""" button_dim = self.stats_ac_space["buttons"].size self.BUTTON_IDX_TO_FACTORED = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION), button_dim), dtype=int) self.BUTTON_IDX_TO_CAMERA_META_OFF = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION)), dtype=bool) self.CAMERA_IDX_TO_FACTORED = np.zeros((len(self.camera_idx_to_combination), 2), dtype=int) # Pre compute Buttons for jnt_ac, button_comb in self.BUTTONS_IDX_TO_COMBINATION.items(): new_button_ac = np.zeros(len(Buttons.ALL), dtype="i") if button_comb == "inventory": new_button_ac[Buttons.ALL.index("inventory")] = 1 else: for group_choice in button_comb[:-1]: # Last one is camera if group_choice != "none": new_button_ac[Buttons.ALL.index(group_choice)] = 1 if button_comb[-1] != "camera": # This means camera meta action is off self.BUTTON_IDX_TO_CAMERA_META_OFF[jnt_ac] = True self.BUTTON_IDX_TO_FACTORED[jnt_ac] = new_button_ac # Pre compute camera for jnt_ac, camera_comb in self.camera_idx_to_combination.items(): new_camera_ac = np.ones((2), dtype="i") * self.camera_null_bin new_camera_ac[0] = self.camera_groups["camera_x"].index(camera_comb[0]) 
new_camera_ac[1] = self.camera_groups["camera_y"].index(camera_comb[1]) self.CAMERA_IDX_TO_FACTORED[jnt_ac] = new_camera_ac def from_factored(self, ac: Dict) -> Dict: """Converts a factored action (ac) to the new space. Assumes ac has a batch dim""" assert ac["camera"].ndim == 2, f"bad camera label, {ac['camera']}" assert ac["buttons"].ndim == 2, f"bad buttons label, {ac['buttons']}" # Get button choices for everything but camera choices_by_group = OrderedDict( (k, self.factored_buttons_to_groups(ac["buttons"], v)) for k, v in self.BUTTONS_GROUPS.items() if k != "camera" ) # Set camera "on off" action based on whether non-null camera action was given camera_is_null = np.all(ac["camera"] == self.camera_null_bin, axis=1) choices_by_group["camera"] = ["none" if is_null else "camera" for is_null in camera_is_null] new_button_ac = [] new_camera_ac = [] for i in range(ac["buttons"].shape[0]): # Buttons key = tuple([v[i] for v in choices_by_group.values()]) if ac["buttons"][i, Buttons.ALL.index("inventory")] == 1: key = "inventory" new_button_ac.append(self.BUTTONS_COMBINATION_TO_IDX[key]) # Camera -- inventory is also exclusive with camera if key == "inventory": key = ( f"camera_x{self.camera_null_bin}", f"camera_y{self.camera_null_bin}", ) else: key = (f"camera_x{ac['camera'][i][0]}", f"camera_y{ac['camera'][i][1]}") new_camera_ac.append(self.camera_combination_to_idx[key]) return dict( buttons=np.array(new_button_ac)[:, None], camera=np.array(new_camera_ac)[:, None], ) def to_factored(self, ac: Dict) -> Dict: """Converts an action in the new space (ac) to the factored action space. Assumes ac has a batch dim""" assert ac["camera"].shape[-1] == 1 assert ac["buttons"].shape[-1] == 1 new_button_ac = self.BUTTON_IDX_TO_FACTORED[np.squeeze(ac["buttons"], -1)] camera_off = self.BUTTON_IDX_TO_CAMERA_META_OFF[np.squeeze(ac["buttons"], -1)] new_camera_ac = self.CAMERA_IDX_TO_FACTORED[np.squeeze(ac["camera"], -1)] new_camera_ac[camera_off] = self.camera_null_bin return dict(buttons=new_button_ac, camera=new_camera_ac) def get_action_space_update(self): return { "camera": TensorType(shape=(1,), eltype=Discrete(len(self.camera_combinations))), "buttons": TensorType(shape=(1,), eltype=Discrete(len(self.BUTTONS_COMBINATIONS))), } def get_zero_action(self): return self._null_action ================================================ FILE: metrics/IDM/lib/actions.py ================================================ import attr # import minerl.herobraine.hero.mc as mc import numpy as np from lib.minecraft_util import store_args class Buttons: ATTACK = "attack" BACK = "back" FORWARD = "forward" JUMP = "jump" LEFT = "left" RIGHT = "right" SNEAK = "sneak" SPRINT = "sprint" USE = "use" DROP = "drop" INVENTORY = "inventory" ALL = [ ATTACK, BACK, FORWARD, JUMP, LEFT, RIGHT, SNEAK, SPRINT, USE, DROP, INVENTORY, ] + [f"hotbar.{i}" for i in range(1, 10)] class SyntheticButtons: # Composite / scripted actions CHANNEL_ATTACK = "channel-attack" ALL = [CHANNEL_ATTACK] class QuantizationScheme: LINEAR = "linear" MU_LAW = "mu_law" @attr.s(auto_attribs=True) class CameraQuantizer: """ A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components. Parameters: - camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize. - camera_maxval: The maximum value of the camera action. - quantization_scheme: The quantization scheme to use. 
Currently, two quantization schemes are supported: - Linear quantization (default): Camera actions are split uniformly into discrete bins - Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm) followed by the same quantization scheme used by the linear scheme. - mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of mu will result in a sharper transition near zero. Below are some reference values listed for choosing mu given a constant maxval and a desired max_precision value. maxval = 10 | max_precision = 0.5 | μ ≈ 2.93826 maxval = 10 | max_precision = 0.4 | μ ≈ 4.80939 maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887 maxval = 20 | max_precision = 0.5 | μ ≈ 2.7 maxval = 20 | max_precision = 0.4 | μ ≈ 4.39768 maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194 maxval = 40 | max_precision = 0.5 | μ ≈ 2.60780 maxval = 40 | max_precision = 0.4 | μ ≈ 4.21554 maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152 """ camera_maxval: int camera_binsize: int quantization_scheme: str = attr.ib( default=QuantizationScheme.LINEAR, validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]), ) mu: float = attr.ib(default=5) def discretize(self, xy): xy = np.clip(xy, -self.camera_maxval, self.camera_maxval) if self.quantization_scheme == QuantizationScheme.MU_LAW: xy = xy / self.camera_maxval v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu)) v_encode *= self.camera_maxval xy = v_encode # Quantize using linear scheme return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64) def undiscretize(self, xy): xy = xy * self.camera_binsize - self.camera_maxval if self.quantization_scheme == QuantizationScheme.MU_LAW: xy = xy / self.camera_maxval v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0) v_decode *= self.camera_maxval xy = v_decode return xy class ActionTransformer: """Transforms actions between internal array and minerl env format.""" @store_args def __init__( self, camera_maxval=10, camera_binsize=2, camera_quantization_scheme="linear", camera_mu=5, ): self.quantizer = CameraQuantizer( camera_maxval=camera_maxval, camera_binsize=camera_binsize, quantization_scheme=camera_quantization_scheme, mu=camera_mu, ) def camera_zero_bin(self): return self.camera_maxval // self.camera_binsize def discretize_camera(self, xy): return self.quantizer.discretize(xy) def undiscretize_camera(self, pq): return self.quantizer.undiscretize(pq) def item_embed_id_to_name(self, item_id): return mc.MINERL_ITEM_MAP[item_id] def dict_to_numpy(self, acs): """ Env format to policy output format. """ act = { "buttons": np.stack([acs.get(k, 0) for k in Buttons.ALL], axis=-1), "camera": self.discretize_camera(acs["camera"]), } if not self.human_spaces: act.update( { "synthetic_buttons": np.stack([acs[k] for k in SyntheticButtons.ALL], axis=-1), "place": self.item_embed_name_to_id(acs["place"]), "equip": self.item_embed_name_to_id(acs["equip"]), "craft": self.item_embed_name_to_id(acs["craft"]), } ) return act def numpy_to_dict(self, acs): """ Numpy policy output to env-compatible format. 
""" assert acs["buttons"].shape[-1] == len( Buttons.ALL ), f"Mismatched actions: {acs}; expected {len(Buttons.ALL)}:\n( {Buttons.ALL})" out = {name: acs["buttons"][..., i] for (i, name) in enumerate(Buttons.ALL)} out["camera"] = self.undiscretize_camera(acs["camera"]) return out def policy2env(self, acs): acs = self.numpy_to_dict(acs) return acs def env2policy(self, acs): nbatch = acs["camera"].shape[0] dummy = np.zeros((nbatch,)) out = { "camera": self.discretize_camera(acs["camera"]), "buttons": np.stack([acs.get(k, dummy) for k in Buttons.ALL], axis=-1), } return out ================================================ FILE: metrics/IDM/lib/impala_cnn.py ================================================ import math from copy import deepcopy from typing import Dict, List, Optional from torch import nn from torch.nn import functional as F from lib import misc from lib import torch_util as tu from lib.util import FanInInitReLULayer class CnnBasicBlock(nn.Module): """ Residual basic block, as in ImpalaCNN. Preserves channel number and shape :param inchan: number of input channels :param init_scale: weight init scale multiplier """ def __init__( self, inchan: int, init_scale: float = 1, log_scope="", init_norm_kwargs: Dict = {}, **kwargs, ): super().__init__() self.inchan = inchan s = math.sqrt(init_scale) self.conv0 = FanInInitReLULayer( self.inchan, self.inchan, kernel_size=3, padding=1, init_scale=s, log_scope=f"{log_scope}/conv0", **init_norm_kwargs, ) self.conv1 = FanInInitReLULayer( self.inchan, self.inchan, kernel_size=3, padding=1, init_scale=s, log_scope=f"{log_scope}/conv1", **init_norm_kwargs, ) def forward(self, x): x = x + self.conv1(self.conv0(x)) return x class CnnDownStack(nn.Module): """ Downsampling stack from Impala CNN. :param inchan: number of input channels :param nblock: number of residual blocks after downsampling :param outchan: number of output channels :param init_scale: weight init scale multiplier :param pool: if true, downsample with max pool :param post_pool_groups: if not None, normalize with group norm with this many groups :param kwargs: remaining kwargs are passed into the blocks and layers """ name = "Impala_CnnDownStack" def __init__( self, inchan: int, nblock: int, outchan: int, init_scale: float = 1, pool: bool = True, post_pool_groups: Optional[int] = None, log_scope: str = "", init_norm_kwargs: Dict = {}, first_conv_norm=False, **kwargs, ): super().__init__() self.inchan = inchan self.outchan = outchan self.pool = pool first_conv_init_kwargs = deepcopy(init_norm_kwargs) if not first_conv_norm: first_conv_init_kwargs["group_norm_groups"] = None first_conv_init_kwargs["batch_norm"] = False self.firstconv = FanInInitReLULayer( inchan, outchan, kernel_size=3, padding=1, log_scope=f"{log_scope}/firstconv", **first_conv_init_kwargs, ) self.post_pool_groups = post_pool_groups if post_pool_groups is not None: self.n = nn.GroupNorm(post_pool_groups, outchan) self.blocks = nn.ModuleList( [ CnnBasicBlock( outchan, init_scale=init_scale / math.sqrt(nblock), log_scope=f"{log_scope}/block{i}", init_norm_kwargs=init_norm_kwargs, **kwargs, ) for i in range(nblock) ] ) def forward(self, x): x = self.firstconv(x) if self.pool: x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) if self.post_pool_groups is not None: x = self.n(x) x = tu.sequential(self.blocks, x, diag_name=self.name) return x def output_shape(self, inshape): c, h, w = inshape assert c == self.inchan if self.pool: return (self.outchan, (h + 1) // 2, (w + 1) // 2) else: return (self.outchan, h, w) class 
ImpalaCNN(nn.Module): """ :param inshape: input image shape (height, width, channels) :param chans: number of residual downsample stacks. Each element is the number of filters per convolution in the stack :param outsize: output hidden size :param nblock: number of residual blocks per stack. Each block has 2 convs and a residual :param init_norm_kwargs: arguments to be passed to convolutional layers. Options can be found in ypt.model.util:FanInInitReLULayer :param dense_init_norm_kwargs: arguments to be passed to convolutional layers. Options can be found in ypt.model.util:FanInInitReLULayer :param kwargs: remaining kwargs are passed into the CnnDownStacks """ name = "ImpalaCNN" def __init__( self, inshape: List[int], chans: List[int], outsize: int, nblock: int, init_norm_kwargs: Dict = {}, dense_init_norm_kwargs: Dict = {}, first_conv_norm=False, **kwargs, ): super().__init__() h, w, c = inshape curshape = (c, h, w) self.stacks = nn.ModuleList() for i, outchan in enumerate(chans): stack = CnnDownStack( curshape[0], nblock=nblock, outchan=outchan, init_scale=math.sqrt(len(chans)), log_scope=f"downstack{i}", init_norm_kwargs=init_norm_kwargs, first_conv_norm=first_conv_norm if i == 0 else True, **kwargs, ) self.stacks.append(stack) curshape = stack.output_shape(curshape) self.dense = FanInInitReLULayer( misc.intprod(curshape), outsize, layer_type="linear", log_scope="imapala_final_dense", init_scale=1.4, **dense_init_norm_kwargs, ) self.outsize = outsize def forward(self, x): b, t = x.shape[:-3] x = x.reshape(b * t, *x.shape[-3:]) x = misc.transpose(x, "bhwc", "bchw") x = tu.sequential(self.stacks, x, diag_name=self.name) x = x.reshape(b, t, *x.shape[1:]) x = tu.flatten_image(x) x = self.dense(x) return x ================================================ FILE: metrics/IDM/lib/masked_attention.py ================================================ import functools import torch as th from torch import nn import lib.xf as xf from lib.minecraft_util import store_args from lib.tree_util import tree_map @functools.lru_cache() def get_band_diagonal_mask(t: int, T: int, maxlen: int, batchsize: int, device: th.device) -> th.Tensor: """Returns a band diagonal mask which is causal (upper triangle is masked) and such that any frame can only view up to maxlen total past frames including the current frame. Example Masks: Here 0 means that frame is masked and we mask it by adding a huge number to the attention logits (see orc.xf) t = 3, T = 3, maxlen = 3 T t 1 0 0 | mask out T > t 1 1 0 | 1 1 1 | t = 3, T = 6, maxlen = 3 t 0 1 1 1 0 0 | mask out T > t 0 0 1 1 1 0 | 0 0 0 1 1 1 | Args: t: number of rows (presumably number of frames recieving gradient) T: number of cols (presumably t + past context that isn't being gradient updated) maxlen: maximum number of frames (including current frame) any frame can attend to batchsize: number of masks to return device: torch device to place mask on Returns: Boolean mask of shape (batchsize, t, T) """ m = th.ones(t, T, dtype=bool) m.tril_(T - t) # Mask out upper triangle if maxlen is not None and maxlen < T: # Mask out lower triangle m.triu_(T - t - maxlen + 1) m_btT = m[None].repeat_interleave(batchsize, dim=0) m_btT = m_btT.to(device=device) return m_btT def get_mask(first_b11: th.Tensor, state_mask: th.Tensor, t: int, T: int, maxlen: int, heads: int, device) -> th.Tensor: """Returns a band diagonal mask that respects masking past states (columns 0:T-t inclusive) if first_b11 is True. See get_band_diagonal_mask for how the base mask is computed. 
This function takes that mask and first zeros out any past context if first_b11 is True. Say our context is in chunks of length t (so here T = 4t). We see that in the second batch we recieved first=True context t t t t first F T F F Now, given this the mask should mask out anything prior to T < t; however since we don't have access to the past first_b11's we need to keep a state of the mask at those past timesteps. This is what state_mask is. In particular state_mask is a [b, t, T - t] mask matrix that contains the mask for the past T - t frames. Args: (See get_band_diagonal_mask for remaining args) first_b11: boolean tensor with shape [batchsize, 1, 1] indicating if the first timestep for each batch element had first=True state_mask: mask tensor of shape [b, t, T - t] t: number of mask rows (presumably number of frames for which we take gradient) T: number of mask columns (t + the number of past frames we keep in context) maxlen: actual context length heads: number of attention heads device: torch device Returns: m_btT: Boolean mask of shape (batchsize * heads, t, T) state_mask: updated state_mask """ b = first_b11.shape[0] if state_mask is None: state_mask = th.zeros((b, 1, T - t), dtype=bool, device=device) m_btT = get_band_diagonal_mask(t, T, maxlen, b, device).clone() # Should be shape B, t, T not_first = ~first_b11.to(device=device) m_btT[:, :, :-t] &= not_first # Zero out anything in the past if first is true m_btT[:, :, :-t] &= state_mask m_bhtT = m_btT[:, None].repeat_interleave(heads, dim=1) m_btT = m_bhtT.reshape((b * heads), t, T) # Update state_mask such that it reflects the most recent first state_mask = th.cat( [ state_mask[:, :, t:] & not_first, th.ones((b, 1, min(t, T - t)), dtype=bool, device=device), ], dim=-1, ) return m_btT, state_mask class MaskedAttention(nn.Module): """ Transformer self-attention layer that removes frames from previous episodes from the hidden state under certain constraints. The constraints are: - The "first" flag can only be true for the first timestep of each batch. An assert will fire if other timesteps have first = True. input_size: The dimension of the input (which also happens to be the size of the output) memory_size: The number of frames to keep in the inner state. Note that when attending, we will be able to attend to both the frames in the inner state (which presumably won't have gradients anymore) and the frames in the batch. "mask" for some additional considerations on this. heads: The number of attention heads to use. Note that we will split the input into this number of heads, so input_size needs to be divisible by heads. timesteps: number of timesteps with which we'll be taking gradient mask: Can be "none" or "clipped_causal". "clipped_causal" is a normal causal mask but solves the following minor problem: if you have a state of length 128 and a batch of 128 frames, then the first frame of your batch will be able to attend to 128 previous frames, but the last one will be able to attend to 255 previous frames. In this example, "clipped_causal" will make it so that the last frame can only attend to 128 previous frames, so that there is no bias coming from the position in the batch. None simply allows you to attend to any frame in the state + batch, which means you can also attend to future frames. 
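# Illustrative sketch (not from the original file): how the "clipped_causal" band mask
# described above is built, mirroring the tril_/triu_ trick in get_band_diagonal_mask
# for the documented t=3, T=6, maxlen=3 case.
import torch as th

def band_diagonal(t, T, maxlen):
    m = th.ones(t, T, dtype=th.bool)
    m.tril_(T - t)                       # causal: a frame cannot attend to the future
    if maxlen is not None and maxlen < T:
        m.triu_(T - t - maxlen + 1)      # window: at most maxlen frames, current one included
    return m

print(band_diagonal(3, 6, 3).int())
# tensor([[0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]], dtype=torch.int32)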
""" @store_args def __init__( self, input_size, memory_size: int, heads: int, timesteps: int, mask: str = "clipped_causal", init_scale=1, norm="none", log_scope="sa", use_muP_factor=False, ): super().__init__() assert mask in {"none", "clipped_causal"} assert memory_size >= 0 self.maxlen = memory_size - timesteps if mask == "none": mask = None self.orc_attn = xf.All2All(heads, self.maxlen, mask=mask is not None) self.orc_block = xf.SelfAttentionLayer( input_size, self.orc_attn, scale=init_scale, relattn=True, cache_keep_len=self.maxlen, norm=norm, log_scope=log_scope, use_muP_factor=use_muP_factor, ) def initial_state(self, batchsize: int, device=None): """Return the initial state mask (None) and the initial state of the transformer (zerod out keys and queries)""" state = self.orc_block.initial_state(batchsize, initial_T=self.maxlen) state_mask = None if device is not None: state = tree_map(lambda x: x.to(device), state) return state_mask, state def forward(self, input_bte, first_bt, state): """Forward propagation of a single layer""" state_mask, xf_state = state t = first_bt.shape[1] if self.mask == "clipped_causal": new_mask, state_mask = get_mask( first_b11=first_bt[:, [[0]]], state_mask=state_mask, t=t, T=t + self.maxlen, maxlen=self.maxlen, heads=self.heads, device=input_bte.device, ) self.orc_block.attn.mask = new_mask output, xf_state = self.orc_block(input_bte, xf_state) return output, (state_mask, xf_state) def get_log_keys(self): # These are logged in xf.SelfAttentionLayer return [f"activation_{stat}/{self.log_scope}/{k}" for k in ["K", "Q", "V", "A", "Aproj"] for stat in ["mean", "std"]] ================================================ FILE: metrics/IDM/lib/minecraft_util.py ================================================ import functools import inspect from typing import Optional, Tuple import numpy as np import torch from lib.action_head import (CategoricalActionHead, DiagGaussianActionHead, DictActionHead) def store_args(method): """Stores provided method args as instance attributes.""" argspec = inspect.getfullargspec(method) defaults = {} if argspec.defaults is not None: defaults = dict(zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)) if argspec.kwonlydefaults is not None: defaults.update(argspec.kwonlydefaults) arg_names = argspec.args[1:] @functools.wraps(method) def wrapper(*positional_args, **keyword_args): self = positional_args[0] # Get default arg values args = defaults.copy() # Add provided arg values for name, value in zip(arg_names, positional_args[1:]): args[name] = value args.update(keyword_args) self.__dict__.update(args) return method(*positional_args, **keyword_args) return wrapper def get_norm_entropy_from_cat_head(module, name, masks, logits): # Note that the mask has already been applied to the logits at this point entropy = -torch.sum(torch.exp(logits) * logits, dim=-1) if name in masks: n = torch.sum(masks[name], dim=-1, dtype=torch.float) norm_entropy = entropy / torch.log(n) # When the mask only allows one option the normalized entropy makes no sense # as it is basically both maximal (the distribution is as uniform as it can be) # and minimal (there is no variance at all). # A such, we ignore them for purpose of calculating entropy. 
zero = torch.zeros_like(norm_entropy) norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy) count = n.not_equal(1.0).int() else: n = torch.tensor(logits.shape[-1], dtype=torch.float) norm_entropy = entropy / torch.log(n) count = torch.ones_like(norm_entropy, dtype=torch.int) # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it. for _ in module.output_shape[:-1]: norm_entropy = norm_entropy.sum(dim=-1) count = count.sum(dim=-1) return norm_entropy, count def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]: entropy_sum = torch.zeros_like(template, dtype=torch.float) counts = torch.zeros_like(template, dtype=torch.int) for k, subhead in module.items(): if isinstance(subhead, DictActionHead): entropy, count = get_norm_cat_entropy(subhead, masks, logits[k], template) elif isinstance(subhead, CategoricalActionHead): entropy, count = get_norm_entropy_from_cat_head(subhead, k, masks, logits[k]) else: continue entropy_sum += entropy counts += count return entropy_sum, counts def get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]: entropy_sum = torch.zeros_like(template, dtype=torch.float) count = torch.zeros(1, device=template.device, dtype=torch.int) for k, subhead in module.items(): if isinstance(subhead, DictActionHead): entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template) elif isinstance(subhead, DiagGaussianActionHead): entropy_sum += module.entropy(logits) else: continue count += 1 return entropy_sum / count ================================================ FILE: metrics/IDM/lib/misc.py ================================================ import numpy as np import torch as th def intprod(xs): """ Product of a sequence of integers """ out = 1 for x in xs: out *= x return out def safezip(*args): """ Check that lengths of sequences are the same, then zip them """ args = [list(a) for a in args] n = len(args[0]) for arg in args[1:]: assert len(arg) == n, f"length mismatch: {list(map(len, args))}" return list(zip(*args)) def transpose(x, before, after): """ Usage: x_bca = transpose(x_abc, 'abc', 'bca') """ assert sorted(before) == sorted(after), f"cannot transpose {before} to {after}" assert x.ndim == len( before ), f"before spec '{before}' has length {len(before)} but x has {x.ndim} dimensions: {tuple(x.shape)}" return x.permute(tuple(before.index(i) for i in after)) def transpose_undo(x, before, after, *, undo=None): """ Usage: x_bca, undo = transpose_undo(x_abc, 'abc', 'bca') x_bca = fully_connected_layer(x_bca) x_abc = undo(x_bca) """ return ( transpose(x, before, after), compose_undo(undo, lambda x: transpose(x, before=after, after=before)), ) def compose_undo(u1, u2): assert u2 is not None if u1 is None: return u2 def u(x): x = u2(x) x = u1(x) return x return u NO_BIND = "__nobind" def _parse_reshape_str(s, kind): assert kind in ("before", "after") result = [] n_underscores = 0 for i, part in enumerate(s.split(",")): part = part.strip() if part == "?" 
and kind == "before": result.append([f"__{i}"]) elif part == "_": result.append([f"{NO_BIND}_{n_underscores}"]) n_underscores += 1 else: result.append([term.strip() for term in part.split("*")]) return result def _infer_part(part, concrete_dim, known, index, full_shape): if type(part) is int: return part assert isinstance(part, list), part lits = [] syms = [] for term in part: if type(term) is int: lits.append(term) elif type(term) is str: syms.append(term) else: raise TypeError(f"got {type(term)} but expected int or str") int_part = 1 for x in lits: int_part *= x if len(syms) == 0: return int_part elif len(syms) == 1 and concrete_dim is not None: assert concrete_dim % int_part == 0, f"{concrete_dim} % {int_part} != 0 (at index {index}, full shape is {full_shape})" v = concrete_dim // int_part if syms[0] in known: assert ( known[syms[0]] == v ), f"known value for {syms[0]} is {known[syms[0]]} but found value {v} at index {index} (full shape is {full_shape})" else: known[syms[0]] = v return concrete_dim else: for i in range(len(syms)): if syms[i] in known: syms[i] = known[syms[i]] else: try: syms[i] = int(syms[i]) except ValueError: pass return lits + syms def _infer_step(args): known, desc, shape = args new_known = known.copy() new_desc = desc.copy() for i in range(len(desc)): if shape is None: concrete_dim = None else: concrete_dim = shape[i] new_desc[i] = _infer_part(part=desc[i], concrete_dim=concrete_dim, known=new_known, index=i, full_shape=shape) return new_known, new_desc, shape def _infer(known, desc, shape): if shape is not None: assert len(desc) == len(shape), f"desc has length {len(desc)} but shape has length {len(shape)} (shape={shape})" known, desc, shape = fixed_point(_infer_step, (known, desc, shape)) return desc, known def fixed_point(f, x, eq=None): if eq is None: eq = lambda a, b: a == b while True: new_x = f(x) if eq(x, new_x): return x else: x = new_x def _infer_question_mark(x, total_product): try: question_mark_index = x.index(["?"]) except ValueError: return x observed_product = 1 for i in range(len(x)): if i != question_mark_index: assert type(x[i]) is int, f"when there is a question mark, there can be no other unknown values (full list: {x})" observed_product *= x[i] assert ( observed_product and total_product % observed_product == 0 ), f"{total_product} is not divisible by {observed_product}" value = total_product // observed_product x = x.copy() x[question_mark_index] = value return x def _ground(x, known, infer_question_mark_with=None): x, known = _infer(known=known, desc=x, shape=None) if infer_question_mark_with: x = _infer_question_mark(x, infer_question_mark_with) for part in x: assert type(part) is int, f"cannot infer value of {part}" return x def _handle_ellipsis(x, before, after): ell = ["..."] try: i = before.index(ell) l = len(x.shape) - len(before) + 1 ellipsis_value = x.shape[i : i + l] ellipsis_value = list(ellipsis_value) before = before[:i] + ellipsis_value + before[i + 1 :] except ValueError: pass try: i = after.index(ell) after = after[:i] + ellipsis_value + after[i + 1 :] except ValueError: pass except UnboundLocalError as e: raise ValueError("there cannot be an ellipsis in 'after' unless there is an ellipsis in 'before'") from e return before, after def reshape_undo(inp, before, after, *, undo=None, known=None, **kwargs): """ Usage: x_Bhwse, undo = reshape_undo( x_bthwe, 'b, t, ..., stride*e', 'b*t, ..., stride, e', stride=7 ) x_Bhwse = do_some_stuff(x_Bhwse) x_bthwe = undo(x_Bhwse) It's necessary to pass known values as keywords only when they 
can't be inferred from the shape. (Eg. in the above example we needed to pass stride but not b, t, or e, since those can be determined from inp.shape once stride is known.) """ if known: known = {**kwargs, **known} else: known = kwargs assert type(before) is type(after), f"{type(before)} != {type(after)}" assert isinstance(inp, (th.Tensor, np.ndarray)), f"require tensor or ndarray but got {type(inp)}" assert isinstance(before, (str, list)), f"require str or list but got {type(before)}" if isinstance(before, str): before = _parse_reshape_str(before, "before") after = _parse_reshape_str(after, "after") before, after = _handle_ellipsis(inp, before, after) before_saved, after_saved = before, after before, known = _infer(known=known, desc=before, shape=inp.shape) before = _ground(before, known, product(inp.shape)) after = _ground(after, known, product(inp.shape)) known = {k: v for k, v in known.items() if not k.startswith(NO_BIND)} assert tuple(inp.shape) == tuple(before), f"expected shape {before} but got shape {inp.shape}" assert product(inp.shape) == product( after ), f"cannot reshape {inp.shape} to {after} because the number of elements does not match" return ( inp.reshape(after), compose_undo(undo, lambda inp: reshape(inp, after_saved, before_saved, known=known)), ) def reshape(*args, **kwargs): """ Please see the documenation for reshape_undo. """ x, _ = reshape_undo(*args, **kwargs) return x def product(xs, one=1): result = one for x in xs: result = result * x return result def exact_div(a, b): assert a % b == 0, f"{a} is not divisible by {b}" return a // b ================================================ FILE: metrics/IDM/lib/mlp.py ================================================ import torch as th from torch import nn from lib import misc from lib import torch_util as tu class MLP(nn.Module): def __init__(self, insize, nhidlayer, outsize, hidsize, hidactiv, dtype=th.float32): super().__init__() self.insize = insize self.nhidlayer = nhidlayer self.outsize = outsize in_sizes = [insize] + [hidsize] * nhidlayer out_sizes = [hidsize] * nhidlayer + [outsize] self.layers = nn.ModuleList( [tu.NormedLinear(insize, outsize, dtype=dtype) for (insize, outsize) in misc.safezip(in_sizes, out_sizes)] ) self.hidactiv = hidactiv def forward(self, x): *hidlayers, finallayer = self.layers for layer in hidlayers: x = layer(x) x = self.hidactiv(x) x = finallayer(x) return x @property def output_shape(self): return (self.outsize,) ================================================ FILE: metrics/IDM/lib/normalize_ewma.py ================================================ import numpy as np import torch import torch.nn as nn class NormalizeEwma(nn.Module): """Normalize a vector of observations - across the first norm_axes dimensions""" def __init__(self, input_shape, norm_axes=2, beta=0.99999, per_element_update=False, epsilon=1e-5): super().__init__() self.input_shape = input_shape self.norm_axes = norm_axes self.epsilon = epsilon self.beta = beta self.per_element_update = per_element_update self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False) self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False) self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False) def reset_parameters(self): self.running_mean.zero_() self.running_mean_sq.zero_() self.debiasing_term.zero_() def running_mean_var(self): debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon) 
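# (debiasing_term is an EWMA of the constant 1.0, so dividing by it corrects the
# zero-initialization bias of the running statistics, much like the bias correction
# used in Adam.)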
debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon) debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2) return debiased_mean, debiased_var def forward(self, input_vector): # Make sure input is float32 input_vector = input_vector.to(torch.float) if self.training: # Detach input before adding it to running means to avoid backpropping through it on # subsequent batches. detached_input = input_vector.detach() batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes))) batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes))) if self.per_element_update: batch_size = np.prod(detached_input.size()[: self.norm_axes]) weight = self.beta ** batch_size else: weight = self.beta self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight)) self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight)) self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight)) mean, var = self.running_mean_var() return (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes] def denormalize(self, input_vector): """Transform normalized data back into original distribution""" mean, var = self.running_mean_var() return input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes] ================================================ FILE: metrics/IDM/lib/policy.py ================================================ from copy import deepcopy from email import policy from typing import Dict, Optional import numpy as np import torch as th from gym3.types import DictType from torch import nn from torch.nn import functional as F from lib.action_head import make_action_head from lib.action_mapping import CameraHierarchicalMapping from lib.impala_cnn import ImpalaCNN from lib.normalize_ewma import NormalizeEwma from lib.scaled_mse_head import ScaledMSEHead from lib.tree_util import tree_map from lib.util import FanInInitReLULayer, ResidualRecurrentBlocks from lib.misc import transpose class ImgPreprocessing(nn.Module): """Normalize incoming images. :param img_statistics: remote path to npz file with a mean and std image. If specified normalize images using this. :param scale_img: If true and img_statistics not specified, scale incoming images by 1/255. """ def __init__(self, img_statistics: Optional[str] = None, scale_img: bool = True): super().__init__() self.img_mean = None if img_statistics is not None: img_statistics = dict(**np.load(img_statistics)) self.img_mean = nn.Parameter(th.Tensor(img_statistics["mean"]), requires_grad=False) self.img_std = nn.Parameter(th.Tensor(img_statistics["std"]), requires_grad=False) else: self.ob_scale = 255.0 if scale_img else 1.0 def forward(self, img): x = img.to(dtype=th.float32) if self.img_mean is not None: x = (x - self.img_mean) / self.img_std else: x = x / self.ob_scale return x class ImgObsProcess(nn.Module): """ImpalaCNN followed by a linear layer. :param cnn_outsize: impala output dimension :param output_size: output size of the linear layer. 
:param dense_init_norm_kwargs: kwargs for linear FanInInitReLULayer :param init_norm_kwargs: kwargs for 2d and 3d conv FanInInitReLULayer """ def __init__( self, cnn_outsize: int, output_size: int, dense_init_norm_kwargs: Dict = {}, init_norm_kwargs: Dict = {}, **kwargs, ): super().__init__() self.cnn = ImpalaCNN( outsize=cnn_outsize, init_norm_kwargs=init_norm_kwargs, dense_init_norm_kwargs=dense_init_norm_kwargs, **kwargs, ) self.linear = FanInInitReLULayer( cnn_outsize, output_size, layer_type="linear", **dense_init_norm_kwargs, ) def forward(self, img): return self.linear(self.cnn(img)) class MinecraftPolicy(nn.Module): """ :param recurrence_type: None - No recurrence, adds no extra layers lstm - (Depreciated). Singular LSTM multi_layer_lstm - Multi-layer LSTM. Uses n_recurrence_layers to determine number of consecututive LSTMs Does NOT support ragged batching multi_masked_lstm - Multi-layer LSTM that supports ragged batching via the first vector. This model is slower Uses n_recurrence_layers to determine number of consecututive LSTMs transformer - Dense transformer :param init_norm_kwargs: kwargs for all FanInInitReLULayers. """ def __init__( self, recurrence_type="lstm", impala_width=1, impala_chans=(16, 32, 32), obs_processing_width=256, hidsize=512, single_output=False, # True if we don't need separate outputs for action/value outputs img_shape=None, scale_input_img=True, only_img_input=False, init_norm_kwargs={}, impala_kwargs={}, # Unused argument assumed by forc. input_shape=None, # pylint: disable=unused-argument active_reward_monitors=None, img_statistics=None, first_conv_norm=False, diff_mlp_embedding=False, attention_mask_style="clipped_causal", attention_heads=8, attention_memory_size=2048, use_pointwise_layer=True, pointwise_ratio=4, pointwise_use_activation=False, n_recurrence_layers=1, recurrence_is_residual=True, timesteps=None, use_pre_lstm_ln=True, # Not needed for transformer **unused_kwargs, ): super().__init__() assert recurrence_type in [ "multi_layer_lstm", "multi_layer_bilstm", "multi_masked_lstm", "transformer", "none", ] active_reward_monitors = active_reward_monitors or {} self.single_output = single_output chans = tuple(int(impala_width * c) for c in impala_chans) self.hidsize = hidsize # Dense init kwargs replaces batchnorm/groupnorm with layernorm self.init_norm_kwargs = init_norm_kwargs self.dense_init_norm_kwargs = deepcopy(init_norm_kwargs) if self.dense_init_norm_kwargs.get("group_norm_groups", None) is not None: self.dense_init_norm_kwargs.pop("group_norm_groups", None) self.dense_init_norm_kwargs["layer_norm"] = True if self.dense_init_norm_kwargs.get("batch_norm", False): self.dense_init_norm_kwargs.pop("batch_norm", False) self.dense_init_norm_kwargs["layer_norm"] = True # Setup inputs self.img_preprocess = ImgPreprocessing(img_statistics=img_statistics, scale_img=scale_input_img) self.img_process = ImgObsProcess( cnn_outsize=256, output_size=hidsize, inshape=img_shape, chans=chans, nblock=2, dense_init_norm_kwargs=self.dense_init_norm_kwargs, init_norm_kwargs=init_norm_kwargs, first_conv_norm=first_conv_norm, **impala_kwargs, ) self.pre_lstm_ln = nn.LayerNorm(hidsize) if use_pre_lstm_ln else None self.diff_obs_process = None self.recurrence_type = recurrence_type self.recurrent_layer = None self.recurrent_layer = ResidualRecurrentBlocks( hidsize=hidsize, timesteps=timesteps, recurrence_type=recurrence_type, is_residual=recurrence_is_residual, use_pointwise_layer=use_pointwise_layer, pointwise_ratio=pointwise_ratio, 
pointwise_use_activation=pointwise_use_activation, attention_mask_style=attention_mask_style, attention_heads=attention_heads, attention_memory_size=attention_memory_size, n_block=n_recurrence_layers, ) self.lastlayer = FanInInitReLULayer(hidsize, hidsize, layer_type="linear", **self.dense_init_norm_kwargs) self.final_ln = th.nn.LayerNorm(hidsize) def output_latent_size(self): return self.hidsize def forward(self, ob, state_in, context): first = context["first"] x = self.img_preprocess(ob["img"]) x = self.img_process(x) if self.diff_obs_process: processed_obs = self.diff_obs_process(ob["diff_goal"]) x = processed_obs + x if self.pre_lstm_ln is not None: x = self.pre_lstm_ln(x) if self.recurrent_layer is not None: x, state_out = self.recurrent_layer(x, first, state_in) else: state_out = state_in x = F.relu(x, inplace=False) x = self.lastlayer(x) x = self.final_ln(x) pi_latent = vf_latent = x if self.single_output: return pi_latent, state_out return (pi_latent, vf_latent), state_out def initial_state(self, batchsize): if self.recurrent_layer: return self.recurrent_layer.initial_state(batchsize) else: return None class MinecraftAgentPolicy(nn.Module): def __init__(self, action_space, policy_kwargs, pi_head_kwargs): super().__init__() self.net = MinecraftPolicy(**policy_kwargs) self.action_space = action_space self.value_head = self.make_value_head(self.net.output_latent_size()) self.pi_head = self.make_action_head(self.net.output_latent_size(), **pi_head_kwargs) def make_value_head(self, v_out_size: int, norm_type: str = "ewma", norm_kwargs: Optional[Dict] = None): return ScaledMSEHead(v_out_size, 1, norm_type=norm_type, norm_kwargs=norm_kwargs) def make_action_head(self, pi_out_size: int, **pi_head_opts): return make_action_head(self.action_space, pi_out_size, **pi_head_opts) def initial_state(self, batch_size: int): return self.net.initial_state(batch_size) def reset_parameters(self): super().reset_parameters() self.net.reset_parameters() self.pi_head.reset_parameters() self.value_head.reset_parameters() def forward(self, obs, first: th.Tensor, state_in): if isinstance(obs, dict): # We don't want to mutate the obs input. obs = obs.copy() # If special "mask" key is in obs, # It's for masking the logits. # We take it out (the network doesn't need it) mask = obs.pop("mask", None) else: mask = None (pi_h, v_h), state_out = self.net(obs, state_in, context={"first": first}) pi_logits = self.pi_head(pi_h, mask=mask) vpred = self.value_head(v_h) return (pi_logits, vpred, None), state_out def get_logprob_of_action(self, pd, action): """ Get logprob of taking action `action` given probability distribution (see `get_gradient_for_action` to get this distribution) """ ac = tree_map(lambda x: x.unsqueeze(1), action) log_prob = self.pi_head.logprob(ac, pd) assert not th.isnan(log_prob).any() return log_prob[:, 0] def get_kl_of_action_dists(self, pd1, pd2): """ Get the KL divergence between two action probability distributions """ return self.pi_head.kl_divergence(pd1, pd2) def get_output_for_observation(self, obs, state_in, first): """ Return gradient-enabled outputs for given observation. Use `get_logprob_of_action` to get log probability of action with the given probability distribution. 
Returns: - probability distribution given observation - value prediction for given observation - new state """ # We need to add a fictitious time dimension everywhere obs = tree_map(lambda x: x.unsqueeze(1), obs) first = first.unsqueeze(1) (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in) return pd, self.value_head.denormalize(vpred)[:, 0], state_out @th.no_grad() def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False): # We need to add a fictitious time dimension everywhere obs = tree_map(lambda x: x.unsqueeze(1), obs) first = first.unsqueeze(1) (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in) if taken_action is None: ac = self.pi_head.sample(pd, deterministic=not stochastic) else: ac = tree_map(lambda x: x.unsqueeze(1), taken_action) log_prob = self.pi_head.logprob(ac, pd) assert not th.isnan(log_prob).any() # After unsqueezing, squeeze back to remove fictitious time dimension result = {"log_prob": log_prob[:, 0], "vpred": self.value_head.denormalize(vpred)[:, 0]} if return_pd: result["pd"] = tree_map(lambda x: x[:, 0], pd) ac = tree_map(lambda x: x[:, 0], ac) return ac, state_out, result @th.no_grad() def v(self, obs, first, state_in): """Predict value for a given mdp observation""" obs = tree_map(lambda x: x.unsqueeze(1), obs) first = first.unsqueeze(1) (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in) # After unsqueezing, squeeze back return self.value_head.denormalize(vpred)[:, 0] class InverseActionNet(MinecraftPolicy): """ Args: conv3d_params: PRE impala 3D CNN params. They are just passed into th.nn.Conv3D. """ def __init__( self, hidsize=512, conv3d_params=None, **MCPoliy_kwargs, ): super().__init__( hidsize=hidsize, # If we're using 3dconv, then we normalize entire impala otherwise don't # normalize the first impala layer since we normalize the input first_conv_norm=conv3d_params is not None, **MCPoliy_kwargs, ) self.conv3d_layer = None if conv3d_params is not None: # 3D conv is the first layer, so don't normalize its input conv3d_init_params = deepcopy(self.init_norm_kwargs) conv3d_init_params["group_norm_groups"] = None conv3d_init_params["batch_norm"] = False self.conv3d_layer = FanInInitReLULayer( layer_type="conv3d", log_scope="3d_conv", **conv3d_params, **conv3d_init_params, ) def forward(self, ob, state_in, context): first = context["first"] x = self.img_preprocess(ob["img"]) # Conv3D Prior to Impala if self.conv3d_layer is not None: x = self._conv3d_forward(x) # Impala Stack x = self.img_process(x) if self.recurrent_layer is not None: x, state_out = self.recurrent_layer(x, first, state_in) x = F.relu(x, inplace=False) pi_latent = self.lastlayer(x) pi_latent = self.final_ln(x) return (pi_latent, None), state_out def _conv3d_forward(self, x): # Convert from (B, T, H, W, C) -> (B, H, W, C, T) x = transpose(x, "bthwc", "bcthw") new_x = [] for mini_batch in th.split(x, 1): new_x.append(self.conv3d_layer(mini_batch)) x = th.cat(new_x) # Convert back x = transpose(x, "bcthw", "bthwc") return x class InverseActionPolicy(nn.Module): def __init__( self, action_space, pi_head_kwargs=None, idm_net_kwargs=None, ): super().__init__() self.action_space = action_space self.net = InverseActionNet(**idm_net_kwargs) pi_out_size = self.net.output_latent_size() pi_head_kwargs = {} if pi_head_kwargs is None else pi_head_kwargs self.pi_head = self.make_action_head(pi_out_size=pi_out_size, **pi_head_kwargs) def make_action_head(self, **kwargs): return 
make_action_head(self.action_space, **kwargs) def reset_parameters(self): super().reset_parameters() self.net.reset_parameters() self.pi_head.reset_parameters() def forward(self, obs, first: th.Tensor, state_in, **kwargs): if isinstance(obs, dict): # We don't want to mutate the obs input. obs = obs.copy() # If special "mask" key is in obs, # It's for masking the logits. # We take it out (the network doesn't need it) mask = obs.pop("mask", None) else: mask = None (pi_h, _), state_out = self.net(obs, state_in=state_in, context={"first": first}, **kwargs) pi_logits = self.pi_head(pi_h, mask=mask) return (pi_logits, None, None), state_out @th.no_grad() def predict( self, obs, deterministic: bool = True, **kwargs, ): (pd, _, _), state_out = self(obs=obs, **kwargs) ac = self.pi_head.sample(pd, deterministic=deterministic) log_prob = self.pi_head.logprob(ac, pd) assert not th.isnan(log_prob).any() result = {"log_prob": log_prob, "pd": pd} return ac, state_out, result def initial_state(self, batch_size: int): return self.net.initial_state(batch_size) ================================================ FILE: metrics/IDM/lib/scaled_mse_head.py ================================================ from typing import Dict, Optional import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init from lib.action_head import fan_in_linear from lib.normalize_ewma import NormalizeEwma class ScaledMSEHead(nn.Module): """ Linear output layer that scales itself so that targets are always normalized to N(0, 1) """ def __init__( self, input_size: int, output_size: int, norm_type: Optional[str] = "ewma", norm_kwargs: Optional[Dict] = None ): super().__init__() self.input_size = input_size self.output_size = output_size self.norm_type = norm_type self.linear = nn.Linear(self.input_size, self.output_size) norm_kwargs = {} if norm_kwargs is None else norm_kwargs self.normalizer = NormalizeEwma(output_size, **norm_kwargs) def reset_parameters(self): init.orthogonal_(self.linear.weight) fan_in_linear(self.linear) self.normalizer.reset_parameters() def forward(self, input_data): return self.linear(input_data) def loss(self, prediction, target): """ Calculate the MSE loss between output and a target. 'Prediction' has to be normalized while target is denormalized. Loss is calculated in a 'normalized' space. 
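# Illustrative sketch (not from the original file): the idea behind computing the loss
# in a normalized space, with a fixed mean/std standing in for the EWMA statistics
# that NormalizeEwma maintains.
import torch
import torch.nn.functional as F

mean, std = torch.tensor(5.0), torch.tensor(2.0)
normalize = lambda x: (x - mean) / std
denormalize = lambda x: x * std + mean

prediction = torch.zeros(4)                         # raw head output lives in normalized space
target = torch.full((4,), 5.0)                      # value target in the original scale
loss = F.mse_loss(prediction, normalize(target))    # train in normalized space
value = denormalize(prediction)                     # read out in the original scale
print(loss.item(), value)                           # 0.0 tensor([5., 5., 5., 5.])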
""" return F.mse_loss(prediction, self.normalizer(target), reduction="mean") def denormalize(self, input_data): """Convert input value from a normalized space into the original one""" return self.normalizer.denormalize(input_data) def normalize(self, input_data): return self.normalizer(input_data) ================================================ FILE: metrics/IDM/lib/torch_util.py ================================================ import functools import itertools import math import os import pickle import re import subprocess import tempfile from contextlib import contextmanager from hashlib import md5, sha1 import numpy as np import torch as th import torch.distributed as dist import torch.distributions as dis import torch.nn.functional as F from torch import nn import lib.tree_util as tree_util from lib import misc def contextmanager_to_decorator(cm): def decorator(fn): @functools.wraps(fn) def newfn(*args, **kwargs): with cm(): return fn(*args, **kwargs) return newfn return decorator def have_cuda(): return th.has_cuda def default_device_type(): return "cuda" if have_cuda() else "cpu" no_grad = contextmanager_to_decorator(th.no_grad) DEFAULT_DEVICE = th.device(type=default_device_type()) def set_default_torch_device(device): global DEFAULT_DEVICE DEFAULT_DEVICE = th.device(device) def dev(): return DEFAULT_DEVICE def zeros(*args, **kwargs): return th.zeros(*args, **kwargs, device=dev()) def ones(*args, **kwargs): return th.ones(*args, **kwargs, device=dev()) def arange(*args, **kwargs): return th.arange(*args, **kwargs, device=dev()) def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs): """ nn.Linear but with normalized fan-in init """ dtype = parse_dtype(dtype) if dtype == th.float32: out = nn.Linear(*args, **kwargs) elif dtype == th.float16: out = LinearF16(*args, **kwargs) else: raise ValueError(dtype) out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True) if kwargs.get("bias", True): out.bias.data *= 0 return out class LinearF16(nn.Linear): def forward(self, x): return F.linear(x, self.weight.half(), self.bias.half() if self.bias is not None else None) class LayerNormF16(nn.LayerNorm): def forward(self, x): return F.layer_norm(x, self.normalized_shape, self.weight.half(), self.bias.half(), self.eps) def LayerNorm(*args, dtype=th.float32, **kwargs): dtype = parse_dtype(dtype) if dtype == th.float32: out = nn.LayerNorm(*args, **kwargs) elif dtype == th.float16: out = LayerNormF16(*args, **kwargs) else: raise ValueError(dtype) out.weight.no_scale = True return out def flatten_image(x): """ Flattens last three dims """ *batch_shape, h, w, c = x.shape return x.reshape((*batch_shape, h * w * c)) def sequential(layers, x, *args, diag_name=None, use_checkpoint=False): for (i, layer) in enumerate(layers): x = layer(x, *args) return x @no_grad def load_average_with_metadata(paths, overrides): n_models = len(paths) model, metadata = load_with_metadata(paths[0], overrides=overrides) for p in model.parameters(): p.mul_(1 / n_models) for p in paths[1:]: new_model, _ = load_with_metadata(p, overrides=overrides) for (n1, p1), (n2, p2) in misc.safezip(model.named_parameters(), new_model.named_parameters()): assert n1 == n2, f"names {n1} and {n2} don't match" p1.add_(p2.mul_(1 / n_models)) return model, metadata def save_kwargs(fn): """ This decorator passes through the user-provided kwargs and adds one more, called save_kwargs, mapping to {"create_fn" : name_of_decorated_fn, "kwargs" : other_kwargs} You put on this decorator on a function that creates a pytorch module. 
This will save the kwargs and the function that was used to create the module. This lets us restore the model state later. """ @functools.wraps(fn) def wrapper(**kwargs): if "save_kwargs" in kwargs: return fn(**kwargs) else: sk = {**kwargs, "create_fn": f"{fn.__module__}:{fn.__name__}"} return fn(save_kwargs=sk, **kwargs) return wrapper def parse_dtype(x): if isinstance(x, th.dtype): return x elif isinstance(x, str): if x == "float32" or x == "float": return th.float32 elif x == "float64" or x == "double": return th.float64 elif x == "float16" or x == "half": return th.float16 elif x == "uint8": return th.uint8 elif x == "int8": return th.int8 elif x == "int16" or x == "short": return th.int16 elif x == "int32" or x == "int": return th.int32 elif x == "int64" or x == "long": return th.int64 elif x == "bool": return th.bool else: raise ValueError(f"cannot parse {x} as a dtype") else: raise TypeError(f"cannot parse {type(x)} as dtype") def index(x, i): """ Batched, broadcasting index of x along dimension i.ndim. For example, if x has shape (1, 2, 3, 4, 5) and i has shape (1, 1, 3) then the result has shape (1, 2, 3, 5) and each value in i must be between 0 and 3. """ assert x.ndim >= i.ndim + 1 gather_dim = i.ndim while i.ndim < x.ndim: i = i.unsqueeze(-1) expand_shape = list(x.shape) expand_shape[gather_dim] = 1 i = i.expand(*expand_shape) xi = th.gather(x, gather_dim, i) assert xi.shape[gather_dim] == 1 return xi.squeeze(gather_dim) ================================================ FILE: metrics/IDM/lib/tree_util.py ================================================ # Copyright 2018 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copied this from jax, made it self-contained # Currently just used for improved_checkpoint import collections import functools import itertools as it from collections.abc import Collection from typing import Dict, List, Optional def unzip2(xys): xs = [] ys = [] for x, y in xys: xs.append(x) ys.append(y) return tuple(xs), tuple(ys) def partial(fun, *args, **kwargs): wrapped = functools.partial(fun, *args, **kwargs) functools.update_wrapper(wrapped, fun) wrapped._bound_args = args # pylint: disable=protected-access return wrapped def safe_zip(*args: Collection) -> List[tuple]: n = len(args[0]) for arg in args[1:]: assert len(arg) == n, "length mismatch: {}".format(list(map(len, args))) return list(zip(*args)) def safe_map(f, *args): args = list(map(list, args)) n = len(args[0]) for arg in args[1:]: assert len(arg) == n, "length mismatch: {}".format(list(map(len, args))) return list(map(f, *args)) def tree_map(f, tree, treat_as_leaves: Optional[List] = None): """Map a function over a pytree to produce a new pytree. Args: f: function to be applied at each leaf. tree: a pytree to be mapped over. Returns: A new pytree with the same structure as `tree` but with the value at each leaf given by `f(x)` where `x` is the value at the corresponding leaf in `tree`. 
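# Illustrative sketch (not from the original file): tree_map applied to a nested pytree,
# relying on the dict/list/tuple/None node registrations at the bottom of this file.
# dicts are rebuilt with sorted keys, and None is a childless node, so f is never
# applied to it.
example = {"camera": [1, 2], "buttons": (3, None)}
print(tree_map(lambda x: x * 10, example))
# {'buttons': (30, None), 'camera': [10, 20]}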
""" if treat_as_leaves is None: treat_as_leaves = [] node_type = node_types.get(type(tree)) if node_type and type(tree) not in treat_as_leaves: children, node_spec = node_type.to_iterable(tree) new_children = [tree_map(f, child, treat_as_leaves) for child in children] return node_type.from_iterable(node_spec, new_children) else: return f(tree) def tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None): """Map a multi-input function over pytree args to produce a new pytree. Args: f: function that takes `1 + len(rest)` arguments, to be applied at the corresponding leaves of the pytrees. tree: a pytree to be mapped over, with each leaf providing the first positional argument to `f`. *rest: a tuple of pytrees, each with the same structure as `tree`. Returns: A new pytree with the same structure as `tree` but with the value at each leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`. """ if treat_as_leaves is None: treat_as_leaves = [] node_type = node_types.get(type(tree)) if node_type and type(tree) not in treat_as_leaves: children, node_spec = node_type.to_iterable(tree) all_children = [children] for other_tree in rest: other_children, other_node_data = node_type.to_iterable(other_tree) if other_node_data != node_spec: raise TypeError("Mismatch: {} != {}".format(other_node_data, node_spec)) all_children.append(other_children) new_children = [tree_multimap(f, *xs, treat_as_leaves=treat_as_leaves) for xs in zip(*all_children)] return node_type.from_iterable(node_spec, new_children) else: return f(tree, *rest) def prefix_multimap(f, treedef, tree, *rest): """Like tree_multimap but only maps down through a tree prefix.""" if isinstance(treedef, PyLeaf): return f(tree, *rest) else: node_type = node_types.get(type(tree)) if node_type != treedef.node_type: raise TypeError("Mismatch: {} != {}".format(treedef.node_type, node_type)) children, node_data = node_type.to_iterable(tree) if node_data != treedef.node_data: raise TypeError("Mismatch: {} != {}".format(treedef.node_data, node_data)) all_children = [children] for other_tree in rest: other_children, other_node_data = node_type.to_iterable(other_tree) if other_node_data != node_data: raise TypeError("Mismatch: {} != {}".format(other_node_data, node_data)) all_children.append(other_children) all_children = zip(*all_children) new_children = [prefix_multimap(f, td, *xs) for td, xs in zip(treedef.children, all_children)] return node_type.from_iterable(node_data, new_children) def walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = None): node_type = node_types.get(type(tree)) if treat_as_leaves is None: treat_as_leaves = [] if node_type and type(tree) not in treat_as_leaves: children, node_spec = node_type.to_iterable(tree) proc_children, child_specs = unzip2([walk_pytree(f_node, f_leaf, child, treat_as_leaves) for child in children]) tree_def = PyTreeDef(node_type, node_spec, child_specs) return f_node(proc_children), tree_def else: return f_leaf(tree), PyLeaf() def build_tree(treedef, xs): if isinstance(treedef, PyLeaf): return xs else: # We use 'iter' for clearer error messages children = safe_map(build_tree, iter(treedef.children), iter(xs)) return treedef.node_type.from_iterable(treedef.node_data, children) def _tree_unflatten(xs, treedef): if isinstance(treedef, PyLeaf): return next(xs) else: children = safe_map(partial(_tree_unflatten, xs), treedef.children) return treedef.node_type.from_iterable(treedef.node_data, 
children) def _num_leaves(treedef): return 1 if isinstance(treedef, PyLeaf) else sum(safe_map(_num_leaves, treedef.children)) def _nested_treedef(inner, outer): # just used in tree_transpose error checking if isinstance(outer, PyLeaf): return inner else: children = safe_map(partial(_nested_treedef, inner), outer.children) return PyTreeDef(outer.node_type, outer.node_data, tuple(children)) class PyTreeDef(object): def __init__(self, node_type, node_data, children): self.node_type = node_type self.node_data = node_data self.children = children def __repr__(self): if self.node_data is None: data_repr = "" else: data_repr = "[{}]".format(self.node_data) return "PyTree({}{}, [{}])".format(self.node_type.name, data_repr, ",".join(safe_map(repr, self.children))) def __hash__(self): return hash((self.node_type, self.node_data, tuple(self.children))) def __eq__(self, other): if isinstance(other, PyLeaf): return False else: return self.node_type == other.node_type and self.node_data == other.node_data and self.children == other.children def __ne__(self, other): return not self == other class PyLeaf(object): def __repr__(self): return "*" def __eq__(self, other): return isinstance(other, PyLeaf) class NodeType(object): def __init__(self, name, to_iterable, from_iterable): self.name = name self.to_iterable = to_iterable self.from_iterable = from_iterable node_types: Dict[type, NodeType] = {} def register_pytree_node(py_type, to_iterable, from_iterable): assert py_type not in node_types node_types[py_type] = NodeType(str(py_type), to_iterable, from_iterable) def tuple_to_iterable(xs): return xs, None def tuple_from_iterable(_keys, xs): return tuple(xs) def list_to_iterable(xs): return tuple(xs), None def list_from_iterable(_keys, xs): return list(xs) def dict_to_iterable(xs): keys = tuple(sorted(xs.keys())) return tuple(map(xs.get, keys)), keys def dict_from_iterable(keys, xs): return dict(safe_zip(keys, xs)) def ordered_dict_from_iterable(keys, xs): return collections.OrderedDict(safe_zip(keys, xs)) def default_dict_to_iterable(xs): return (tuple(xs.values()), (xs.default_factory, tuple(xs.keys()))) def default_dict_from_iterable(keys, xs): return collections.defaultdict(keys[0], safe_zip(keys[1], xs)) def none_to_iterable(_xs): return (), None def none_from_iterable(_keys, _xs): return None register_pytree_node(tuple, tuple_to_iterable, tuple_from_iterable) register_pytree_node(list, list_to_iterable, list_from_iterable) register_pytree_node(dict, dict_to_iterable, dict_from_iterable) register_pytree_node(collections.OrderedDict, dict_to_iterable, ordered_dict_from_iterable) register_pytree_node(collections.defaultdict, default_dict_to_iterable, default_dict_from_iterable) register_pytree_node(type(None), none_to_iterable, none_from_iterable) ================================================ FILE: metrics/IDM/lib/util.py ================================================ from typing import Dict, Optional import torch as th from torch import nn from torch.nn import functional as F import lib.torch_util as tu from lib.masked_attention import MaskedAttention from lib.minecraft_util import store_args from lib.tree_util import tree_map def get_module_log_keys_recursive(m: nn.Module): """Recursively get all keys that a module and its children want to log.""" keys = [] if hasattr(m, "get_log_keys"): keys += m.get_log_keys() for c in m.children(): keys += get_module_log_keys_recursive(c) return keys class FanInInitReLULayer(nn.Module): """Implements a slightly modified init that correctly produces std 1 outputs 
given ReLU activation :param inchan: number of input channels :param outchan: number of output channels :param layer_args: positional layer args :param layer_type: options are "linear" (dense layer), "conv" (2D Convolution), "conv3d" (3D convolution) :param init_scale: multiplier on initial weights :param batch_norm: use batch norm after the layer (for 2D data) :param group_norm_groups: if not None, use group norm with this many groups after the layer. Group norm 1 would be equivalent of layernorm for 2D data. :param layer_norm: use layernorm after the layer (for 1D data) :param layer_kwargs: keyword arguments for the layer """ @store_args def __init__( self, inchan: int, outchan: int, *layer_args, layer_type: str = "conv", init_scale: int = 1, batch_norm: bool = False, batch_norm_kwargs: Dict = {}, group_norm_groups: Optional[int] = None, layer_norm: bool = False, use_activation=True, log_scope: Optional[str] = None, **layer_kwargs, ): super().__init__() # Normalization self.norm = None if batch_norm: self.norm = nn.BatchNorm2d(inchan, **batch_norm_kwargs) elif group_norm_groups is not None: self.norm = nn.GroupNorm(group_norm_groups, inchan) elif layer_norm: self.norm = nn.LayerNorm(inchan) layer = dict(conv=nn.Conv2d, conv3d=nn.Conv3d, linear=nn.Linear)[layer_type] self.layer = layer(inchan, outchan, bias=self.norm is None, *layer_args, **layer_kwargs) # Init Weights (Fan-In) self.layer.weight.data *= init_scale / self.layer.weight.norm( dim=tuple(range(1, self.layer.weight.data.ndim)), p=2, keepdim=True ) # Init Bias if self.layer.bias is not None: self.layer.bias.data *= 0 def forward(self, x): """Norm after the activation. Experimented with this for both IAM and BC and it was slightly better.""" if self.norm is not None: x = self.norm(x) x = self.layer(x) if self.use_activation: x = F.relu(x, inplace=True) return x def get_log_keys(self): return [ f"activation_mean/{self.log_scope}", f"activation_std/{self.log_scope}", ] class ResidualRecurrentBlocks(nn.Module): @store_args def __init__( self, n_block=2, recurrence_type="multi_layer_lstm", is_residual=True, **block_kwargs, ): super().__init__() init_scale = n_block ** -0.5 if is_residual else 1 self.blocks = nn.ModuleList( [ ResidualRecurrentBlock( **block_kwargs, recurrence_type=recurrence_type, is_residual=is_residual, init_scale=init_scale, block_number=i, ) for i in range(n_block) ] ) def forward(self, x, first, state): state_out = [] assert len(state) == len( self.blocks ), f"Length of state {len(state)} did not match length of blocks {len(self.blocks)}" for block, _s_in in zip(self.blocks, state): x, _s_o = block(x, first, _s_in) state_out.append(_s_o) return x, state_out def initial_state(self, batchsize): if "lstm" in self.recurrence_type: return [None for b in self.blocks] else: return [b.r.initial_state(batchsize) for b in self.blocks] class ResidualRecurrentBlock(nn.Module): @store_args def __init__( self, hidsize, timesteps, init_scale=1, recurrence_type="multi_layer_lstm", is_residual=True, use_pointwise_layer=True, pointwise_ratio=4, pointwise_use_activation=False, attention_heads=8, attention_memory_size=2048, attention_mask_style="clipped_causal", log_scope="resblock", block_number=0, ): super().__init__() self.log_scope = f"{log_scope}{block_number}" s = init_scale if use_pointwise_layer: if is_residual: s *= 2 ** -0.5 # second residual self.mlp0 = FanInInitReLULayer( hidsize, hidsize * pointwise_ratio, init_scale=1, layer_type="linear", layer_norm=True, log_scope=self.log_scope + "/ptwise_mlp0", ) self.mlp1 = 
FanInInitReLULayer( hidsize * pointwise_ratio, hidsize, init_scale=s, layer_type="linear", use_activation=pointwise_use_activation, log_scope=self.log_scope + "/ptwise_mlp1", ) self.pre_r_ln = nn.LayerNorm(hidsize) if recurrence_type in ["multi_layer_lstm", "multi_layer_bilstm"]: self.r = nn.LSTM(hidsize, hidsize, batch_first=True) nn.init.normal_(self.r.weight_hh_l0, std=s * (self.r.weight_hh_l0.shape[0] ** -0.5)) nn.init.normal_(self.r.weight_ih_l0, std=s * (self.r.weight_ih_l0.shape[0] ** -0.5)) self.r.bias_hh_l0.data *= 0 self.r.bias_ih_l0.data *= 0 elif recurrence_type == "transformer": self.r = MaskedAttention( input_size=hidsize, timesteps=timesteps, memory_size=attention_memory_size, heads=attention_heads, init_scale=s, norm="none", log_scope=log_scope + "/sa", use_muP_factor=True, mask=attention_mask_style, ) def forward(self, x, first, state): residual = x x = self.pre_r_ln(x) x, state_out = recurrent_forward( self.r, x, first, state, reverse_lstm=self.recurrence_type == "multi_layer_bilstm" and (self.block_number + 1) % 2 == 0, ) if self.is_residual and "lstm" in self.recurrence_type: # Transformer already residual. x = x + residual if self.use_pointwise_layer: # Residual MLP residual = x x = self.mlp1(self.mlp0(x)) if self.is_residual: x = x + residual return x, state_out def recurrent_forward(module, x, first, state, reverse_lstm=False): if isinstance(module, nn.LSTM): if state is not None: # In case recurrent models do not accept a "first" argument we zero out the hidden state here mask = 1 - first[:, 0, None, None].to(th.float) state = tree_map(lambda _s: _s * mask, state) state = tree_map(lambda _s: _s.transpose(0, 1), state) # NL, B, H if reverse_lstm: x = th.flip(x, [1]) x, state_out = module(x, state) if reverse_lstm: x = th.flip(x, [1]) state_out = tree_map(lambda _s: _s.transpose(0, 1), state_out) # B, NL, H return x, state_out else: return module(x, first, state) def _banded_repeat(x, t): """ Repeats x with a shift. 
For example (ignoring the batch dimension): _banded_repeat([A B C D E], 4) = [D E 0 0 0] [C D E 0 0] [B C D E 0] [A B C D E] """ b, T = x.shape x = th.cat([x, x.new_zeros(b, t - 1)], dim=1) result = x.unfold(1, T, 1).flip(1) return result def bandify(b_nd, t, T): """ b_nd -> D_ntT, where "n" indexes over basis functions "d" indexes over time differences "t" indexes over output time "T" indexes over input time only t >= T is nonzero B_ntT[n, t, T] = b_nd[n, t - T] """ nbasis, bandsize = b_nd.shape b_nd = b_nd[:, th.arange(bandsize - 1, -1, -1)] if bandsize >= T: b_nT = b_nd[:, -T:] else: b_nT = th.cat([b_nd.new_zeros(nbasis, T - bandsize), b_nd], dim=1) D_tnT = _banded_repeat(b_nT, t) return D_tnT def get_norm(name, d, dtype=th.float32): if name == "none": return lambda x: x elif name == "layer": return tu.LayerNorm(d, dtype=dtype) else: raise NotImplementedError(name) ================================================ FILE: metrics/IDM/lib/xf.py ================================================ """ Implementation of transformer and reshaping-based sparse transformer """ import functools import math import torch as th from torch import nn from torch.nn import functional as F from lib import misc, mlp from lib import torch_util as tu from lib import util SENTINEL = 0.1337 def attention( Q_bte, K_bTe, V_bTe, dtype, mask=True, extra_btT=None, maxlen=None, check_sentinel=False, use_muP_factor=False, ): """ performs softmax(Q*K)*V operation t : output (write) time axis, possibly size=1 for just the last timestep T : input (read) time axis t < T is OK 'check_sentinel' is used when you want to make it impossible to attend to certain keys. All keys where every value is equal to the constant SENTINEL will be ignored. Currently this is only used by StridedAttn. """ assert Q_bte.dtype == K_bTe.dtype == dtype, f"{Q_bte.dtype}, {K_bTe.dtype}, {dtype} must all match" e = Q_bte.shape[2] if check_sentinel: invalid = (K_bTe == SENTINEL).int().sum(dim=-1) == e invalid = misc.reshape(invalid, "b, T", "b, 1, T") if isinstance(mask, th.Tensor): bias = (~mask).float() * -1e9 elif mask: bias = get_attn_bias_cached(Q_bte.shape[1], K_bTe.shape[1], maxlen=maxlen, device=Q_bte.device, dtype=th.float32) else: bias = Q_bte.new_zeros((), dtype=th.float32) if extra_btT is not None: bias = bias + extra_btT # Equivalent to bias + (1 / math.sqrt(e)) * th.einsum("bte,bpe->btp", Q_bte, K_bte) # but faster: logit_btT = th.baddbmm( bias, Q_bte.float(), K_bTe.float().transpose(-1, -2), alpha=(1 / e) if use_muP_factor else (1 / math.sqrt(e)), ) if check_sentinel: logit_btT = logit_btT - 1e9 * invalid.float() W_btT = th.softmax(logit_btT, dim=2).to(dtype) if callable(V_bTe): # This is used by the sharded video model to defer waiting on # the broadcast of the values until they're needed V_bTe = V_bTe() # th.einsum only lets you use lowercase letters, so 'p' for 'past' # means 'T' A_bte = th.einsum("btp,bpe->bte", W_btT, V_bTe) return A_bte class Attn: """ Defines an attention mechanism All the mechanisms here can be defined by two operations: 1. preprocessing Q,K,V,R[=relative attention query] to move axes from embedding dimension to batch dimension, and possibly doing shifts. 2. postprocessing the final result to move axes back to embedding axis. 
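# Illustrative sketch (not from the original file): the core logit computation in
# attention() above. th.baddbmm with alpha=1/sqrt(e) is a fused form of
# bias + scaled Q K^T; shapes follow the b/t/T/e convention used in this module.
import math
import torch as th

b, t, T, e = 2, 3, 5, 8
Q_bte, K_bTe = th.randn(b, t, e), th.randn(b, T, e)
bias_btT = th.zeros(b, t, T)

logit_fused = th.baddbmm(bias_btT, Q_bte, K_bTe.transpose(-1, -2), alpha=1 / math.sqrt(e))
logit_naive = bias_btT + (1 / math.sqrt(e)) * th.einsum("bte,bpe->btp", Q_bte, K_bTe)
print(th.allclose(logit_fused, logit_naive, atol=1e-5))   # True
W_btT = th.softmax(logit_fused, dim=2)                    # attention weights over past frames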
""" def __init__(self, mask, maxlen): self.mask = mask self.maxlen = maxlen def preproc_qkv(self, Q_bte, K_bte, V_bte): raise NotImplementedError def preproc_r(self, R_btn): raise NotImplementedError def split_heads(x_bte, h): b, t, e = x_bte.shape assert e % h == 0, "Embsize must be divisible by number of heads" q = e // h x_bthq = x_bte.reshape((b, t, h, q)) x_bhtq = misc.transpose(x_bthq, "bthq", "bhtq") x_Btq = x_bhtq.reshape((b * h, t, q)) return x_Btq class All2All(Attn): def __init__(self, nhead, maxlen, mask=True, head_dim=None): super().__init__(mask=mask, maxlen=maxlen) assert (nhead is None) != (head_dim is None), "exactly one of nhead and head_dim must be specified" self.h = nhead self.head_dim = head_dim def preproc_qkv(self, *xs): q = xs[0].shape[-1] for x in xs: assert x.shape[-1] == q, "embedding dimensions do not match" h = self.h or misc.exact_div(q, self.head_dim) postproc = functools.partial(self.postproc_a, h=h) return (postproc, *tuple(split_heads(x, h) for x in xs)) def preproc_r(self, R_btn): _, ret = self.preproc_qkv(R_btn) return ret def postproc_a(self, A_Btq, h): B, t, q = A_Btq.shape b = B // h A_bhtq = A_Btq.reshape((b, h, t, q)) A_bthq = misc.transpose(A_bhtq, "bhtq", "bthq") A_bte = A_bthq.reshape((b, t, h * q)) return A_bte def _required_padding(dim, target_div): if dim % target_div == 0: return 0 else: return target_div - dim % target_div class StridedAttn(Attn): def __init__(self, nhead, stride, maxlen, mask=True): super().__init__(mask=mask, maxlen=maxlen) self.h = nhead self.stride = stride def _preproc(self, x, name, Q_t=None, Q_pad=None): x, undo = misc.reshape_undo(x, "b, t*stride, e", "b, 1, t, stride*e", stride=self.stride) if name == "Q": Q_pad = _required_padding(x.shape[2], self.maxlen) original_t = x.shape[2] x = F.pad(x, (0, 0, 0, Q_pad), value=SENTINEL) undo = misc.compose_undo(undo, lambda x: x[:, :, :original_t]) if name == "Q": Q_t = x.shape[2] assert Q_t % self.maxlen == 0, f"{Q_t} % {self.maxlen} != 0" else: required_len = Q_t + self.maxlen if x.shape[2] < required_len: x = F.pad(x, (0, 0, required_len - x.shape[2], 0), value=SENTINEL) assert x.shape[2] >= required_len back = x[:, :, -Q_t - self.maxlen : -self.maxlen] front = x[:, :, -Q_t:] x = th.cat([back, front], dim=1) _, _, t, _ = x.shape assert t == Q_t, f"{t} != {Q_t}" x, undo = misc.reshape_undo( x, "b, pad_shift, t*maxlen, stride*h*q", "b, pad_shift, t, maxlen, stride, h, q", maxlen=self.maxlen, h=self.h, stride=self.stride, undo=undo, ) x, undo = misc.transpose_undo(x, "bptmshq", "bthspmq", undo=undo) x, undo = misc.reshape_undo( x, "b, t, h, stride, pad_shift, maxlen, q", "b*t*h*stride, pad_shift*maxlen, q", undo=undo, ) if name == "Q": return x, undo, Q_t, Q_pad else: return x def preproc_qkv(self, Q_bte, K_bte, V_bte): pad = _required_padding(Q_bte.shape[1], self.stride) if pad: Q_bte = F.pad(Q_bte, (0, 0, 0, pad), value=SENTINEL) K_bte = F.pad(K_bte, (0, 0, 0, pad), value=SENTINEL) if K_bte is not None else None V_bte = F.pad(V_bte, (0, 0, 0, pad), value=SENTINEL) if V_bte is not None else None undo = lambda x, pad=pad: x[:, :-pad] else: undo = None if K_bte is not None: pad = _required_padding(K_bte.shape[1], self.stride) if pad: K_bte = F.pad(K_bte, (0, 0, pad, 0), value=SENTINEL) V_bte = F.pad(V_bte, (0, 0, pad, 0), value=SENTINEL) assert Q_bte.shape[1] % self.stride == 0 assert K_bte is None or K_bte.shape[1] % self.stride == 0 assert V_bte is None or V_bte.shape[1] % self.stride == 0 Q, postproc, Q_t, Q_pad = self._preproc(Q_bte, "Q") postproc = misc.compose_undo(undo, 
postproc) return ( postproc, Q, self._preproc(K_bte, "K", Q_t=Q_t, Q_pad=Q_pad) if K_bte is not None else None, self._preproc(V_bte, "V", Q_t=Q_t, Q_pad=Q_pad) if V_bte is not None else None, ) def preproc_r(self, R_bte): _, R, _, _ = self.preproc_qkv(R_bte, None, None) return R Q_SCALE = 0.1 K_SCALE = 0.2 V_SCALE = 1.0 PROJ_SCALE = 1.0 MLP0_SCALE = 1.0 MLP1_SCALE = 1.0 R_SCALE = 0.1 B_SCALE = 0.2 class AttentionLayerBase(nn.Module): def __init__( self, *, attn, scale, x_size, c_size, qk_size, v_size, dtype, relattn=False, seqlens=None, separate=False, ): super().__init__() dtype = tu.parse_dtype(dtype) self.attn = attn self.x_size = x_size self.c_size = c_size s = math.sqrt(scale) separgs = dict(seqlens=seqlens, separate=separate) self.q_layer = MultiscaleLinear(x_size, qk_size, name="q", scale=Q_SCALE, dtype=dtype, **separgs) self.k_layer = MultiscaleLinear(c_size, qk_size, name="k", scale=K_SCALE, bias=False, dtype=dtype, **separgs) self.v_layer = MultiscaleLinear(c_size, v_size, name="v", scale=V_SCALE * s, bias=False, dtype=dtype, **separgs) self.proj_layer = MultiscaleLinear(v_size, x_size, name="proj", scale=PROJ_SCALE * s, dtype=dtype, **separgs) self.relattn = relattn maxlen = attn.maxlen assert maxlen > 0 or not attn.mask if self.relattn: nbasis = 10 self.r_layer = tu.NormedLinear(x_size, nbasis * attn.h, scale=R_SCALE, dtype=dtype) self.b_nd = nn.Parameter(th.randn(nbasis, maxlen) * B_SCALE) self.maxlen = maxlen self.dtype = dtype def relattn_logits(self, X_bte, T): R_btn = self.r_layer(X_bte).float() R_btn = self.attn.preproc_r(R_btn) t = R_btn.shape[1] D_ntT = util.bandify(self.b_nd, t, T) extra_btT = th.einsum("btn,ntp->btp", R_btn, D_ntT) return extra_btT def quick_gelu(x): return x * th.sigmoid(1.702 * x) def act(actname, x): if actname == "relu": return F.relu(x) elif actname == "gelu": return quick_gelu(x) elif actname == "none": return x else: raise NotImplementedError(actname) class SelfAttentionLayer(AttentionLayerBase): """ Residual attention layer that takes a single tensor x and has it attend to itself Has the form output = x + f(x) """ def __init__( self, x_size, attn, scale, dtype="float32", norm="layer", cache_keep_len=None, relattn=False, log_scope="sa", use_muP_factor=False, **kwargs, ): super().__init__( x_size=x_size, c_size=x_size, qk_size=x_size, v_size=x_size, attn=attn, scale=scale, relattn=relattn, dtype=dtype, **kwargs, ) self.ln_x = util.get_norm(norm, x_size, dtype=dtype) if cache_keep_len is None: if hasattr(attn, "cache_keep_len"): cache_keep_len = attn.cache_keep_len else: if isinstance(attn, StridedAttn): stride = attn.stride else: stride = 1 cache_keep_len = stride * attn.maxlen self.cache_keep_len = cache_keep_len self.log_scope = log_scope self.use_muP_factor = use_muP_factor def residual(self, X_bte, state): X_bte = self.ln_x(X_bte) Q_bte = self.q_layer(X_bte) K_bte = self.k_layer(X_bte) V_bte = self.v_layer(X_bte) if state: state, K_bte, V_bte = self.update_state(state, K_bte, V_bte) postproc_closure, Q_bte, K_bte, V_bte = self.attn.preproc_qkv(Q_bte, K_bte, V_bte) extra_btT = self.relattn_logits(X_bte, K_bte.shape[1]) if self.relattn else None A_bte = attention( Q_bte, K_bte, V_bte, mask=self.attn.mask, extra_btT=extra_btT, maxlen=self.maxlen, dtype=self.dtype, check_sentinel=isinstance(self.attn, StridedAttn), use_muP_factor=self.use_muP_factor, ) A_bte = postproc_closure(A_bte) Aproj_bte = self.proj_layer(A_bte) return Aproj_bte, state def forward(self, X_bte, state): R_bte, state = self.residual(X_bte, state) return X_bte + R_bte, state 
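# State-handling summary for the methods below: `state` is either None, in which
# case the layer attends over the full input with no caching, or a (K_cache, V_cache)
# pair as produced by initial_state() and previous forward() calls. update_state()
# appends the new keys/values to the cache and keeps at most cache_keep_len cached
# timesteps, so the first new timestep still sees a context of up to maxlen past positions.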
def stateless_forward(self, X_bte): out_bte, _state = self.forward(X_bte, None) return out_bte def update_state(self, state, K_bte, V_bte): def append(prev, new): """ Given `prev` keys from cache, and `new` keys, returns (cache, full), where - cache goes into the output state, length chosen so that on the next timestep, there are enough cached timesteps to get the full context of lenth self.maxlen. - full is used for the current forward pass, with length chosen so that the first timestep new[:, 0] gets to see a context of self.maxlen. """ tprev = prev.shape[1] startfull = max(tprev - self.cache_keep_len, 0) full = th.cat([prev[:, startfull:], new], dim=1) outstate = full[:, max(full.shape[1] - (self.cache_keep_len), 0) :] # To see that the preceding slicing is correct, consider the case # that maxlen==1. Then `full` only consists of `new`, and # `outstate` is empty return outstate, full instate_K, instate_V = state outstate_K, K_bte = append(instate_K, K_bte) outstate_V, V_bte = append(instate_V, V_bte) assert outstate_K.shape[-2] <= self.cache_keep_len return (outstate_K, outstate_V), K_bte, V_bte def initial_state(self, batchsize, initial_T=0): return ( tu.zeros((batchsize, initial_T, self.x_size), dtype=self.dtype), tu.zeros((batchsize, initial_T, self.x_size), dtype=self.dtype), ) def empty_state(self): return None class PointwiseLayer(nn.Module): """ Residual MLP applied at each timestep """ def __init__(self, x_size, scale, dtype, norm, actname="relu", mlp_ratio=2): super().__init__() s = math.sqrt(scale) self.ln = util.get_norm(norm, x_size, dtype=dtype) self.mlp = mlp.MLP( insize=x_size, nhidlayer=1, outsize=x_size, hidsize=int(x_size * mlp_ratio), hidactiv=functools.partial(act, actname), dtype=dtype, ) self.mlp.layers[0].weight.data *= MLP0_SCALE * s self.mlp.layers[1].weight.data *= MLP1_SCALE * s def residual(self, x): x = self.ln(x) x = self.mlp(x) return x def forward(self, x): return x + self.residual(x) def _is_separate(sep, name): if isinstance(sep, bool): return sep assert isinstance(sep, set) if name in sep: sep.remove(name) return True else: return False def make_maybe_multiscale(make_fn, *args, seqlens, separate, name, **kwargs): """ This function either creates one instance of a module or creates a separate instance of the module for each resolution of the image, determined by the `separate` parameter. We create separate modules if `separate` is True or if `separate` is a set containing `name`. 
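When separate instances are created they are wrapped in a SplitCallJoin module, which splits the input along the (flattened) time axis according to `seqlens`, runs each sub-module on its own segment, and concatenates the outputs again.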
""" if _is_separate(separate, name): modules = [make_fn(*args, **kwargs) for _ in seqlens] return SplitCallJoin(modules, seqlens) else: return make_fn(*args, **kwargs) class SplitCallJoin(nn.Module): def __init__(self, mods, seqlens): super().__init__() self.mods = nn.ModuleList(mods) self.seqlens = seqlens def forward(self, x): tl = sum(self.seqlens) x, undo = misc.reshape_undo(x, "..., z*tl, e", "..., z, tl, e", tl=tl) x = list(th.split(x, self.seqlens, dim=-2)) new_x = [] for x, mod in misc.safezip(x, self.mods): x, this_undo = misc.reshape_undo(x, "..., z, l, e", "..., z*l, e") x = mod(x) x = this_undo(x) new_x.append(x) x = th.cat(new_x, dim=-2) x = undo(x) return x MultiscaleLinear = functools.partial(make_maybe_multiscale, tu.NormedLinear) MultiscalePointwise = functools.partial(make_maybe_multiscale, PointwiseLayer) ================================================ FILE: metrics/common_metrics.py ================================================ import sys import os sys.path.append(os.getcwd()) from common_metrics_on_video_quality.calculate_fvd import calculate_fvd from common_metrics_on_video_quality.calculate_lpips import calculate_lpips from common_metrics_on_video_quality.calculate_ssim import calculate_ssim from common_metrics_on_video_quality.calculate_psnr import calculate_psnr import os import cv2 import torch import numpy as np import argparse import json device = torch.device("cuda") def load_videos_to_tensor(video_dir, number_of_videos, video_length, channel, size,video_files=None): videos_tensor = torch.zeros(number_of_videos, video_length, channel, size[0], size[1], requires_grad=False) if video_files is None: video_files = [f for f in os.listdir(video_dir) if f.endswith(('.mp4'))] video_files = sorted(video_files, key=lambda x: int(x.split("_")[-1].split(".")[0])) video_files = video_files[:number_of_videos] for i, video_file in enumerate(video_files): video_path = os.path.join(video_dir, video_file) cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print(f"Failed to open video: {video_path}") continue frames = [] # get video total length ; our gt has 16 frame but we only use 15 frame so set video_length to 15 and start frame to 1 real_video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if real_video_length > video_length: # set start frame to video_length - video_length start_frame = real_video_length - video_length cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) else: # set start frame to 0 cap.set(cv2.CAP_PROP_POS_FRAMES, 0) while len(frames) < video_length: ret, frame = cap.read() if not ret: break frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = cv2.resize(frame, (size[1], size[0])) # Resize to (height, width) frames.append(frame) if len(frames) < video_length: print(f"Video {video_file} has fewer frames than expected. 
Expected: {video_length}, Found: {len(frames)} Exiting...") exit(1) cap.release() frames_np = np.array(frames[:video_length]) frames_np = np.transpose(frames_np, (0, 3, 1, 2)) videos_tensor[i] = torch.tensor(frames_np, dtype=torch.float32) / 255.0 return videos_tensor # python scripts/tvideo/mc/common_metrics.py --video_dir1 metrics_table_1/oasis/oasis_official_results_no_demo_1gen15_mineworld_curation --video_dir2 metrics_table_1/frame_16_curation --video_length 15 --channel 3 --size "(224,384)" --output-file test_metrics.json def main(): parser = argparse.ArgumentParser(description="Calculate FVD for two sets of videos.") parser.add_argument("--video_dir1", type=str, required=True, help="Path to the first directory containing videos.") parser.add_argument("--video_dir2", type=str, required=True, help="Path to the second directory containing videos.") parser.add_argument("--video_length", type=int, default=32, help="Number of frames to retain from each video.") parser.add_argument("--channel", type=int, default=3, help="Number of channels in the videos (default: 3 for RGB).") parser.add_argument("--size", type=str, default="(224,384)", help="Size of the video frames (default: 256x256).") parser.add_argument("--output-file", type=str) args = parser.parse_args() args.size = eval(args.size) print("args.size", args.size) number_of_videos = len([f for f in os.listdir(args.video_dir1) if f.endswith(".mp4")]) video_files = [f for f in os.listdir(args.video_dir1) if f.endswith(('.mp4'))] number_of_videos = min(500,len(video_files)) print("number_of_videos", number_of_videos) videos1 = load_videos_to_tensor(args.video_dir1, number_of_videos, args.video_length, args.channel, args.size, video_files) videos2 = load_videos_to_tensor(args.video_dir2, number_of_videos, args.video_length, args.channel, args.size, video_files) print("videos1.shape", videos1.shape, "videos2.shape", videos2.shape) device = torch.device("cuda") print(args.output_file) result = {} result['fvd'] = calculate_fvd(videos1, videos2, device, method='styleganv') # result['fvd'] = calculate_fvd(videos1, videos2, device, method='videogpt') result['ssim'] = calculate_ssim(videos1, videos2) result['psnr'] = calculate_psnr(videos1, videos2) result['lpips'] = calculate_lpips(videos1, videos2, device) lpips_mean = np.mean(list(result['lpips']['value'])) ssim_mean = np.mean(list(result['ssim']['value'])) psnr_mean = np.mean(list(result['psnr']['value'])) fvd_mean = np.mean(list(result['fvd']['value'])) data_item = {"exp_name":args.video_dir1, "fvd":fvd_mean, "lpips":lpips_mean, "ssim":ssim_mean, "psnr":psnr_mean} print(data_item) os.makedirs(os.path.dirname(args.output_file), exist_ok=True) result['mean'] = data_item with open(args.output_file, "w") as f: json.dump(result, f, indent=4) print("results saved to ", args.output_file) if __name__ == "__main__": main() ================================================ FILE: metrics/tabulate_all_results.py ================================================ import argparse import os import json import sys import pandas as pd import numpy as np from rich import print def tabluate_metrics(input_dir,output_path): metrics_list = [] all_files = [f for f in os.listdir(input_dir) if f.endswith('.json')] idm_results = [f for f in all_files if 'idm' in f] fvd_results = [f for f in all_files if 'fvd' in f] exps = set([i.replace("idm_","").replace(".json","") for i in idm_results]) & set([i.replace("fvd_","").replace(".json","") for i in fvd_results]) exps = list(exps) print(f"[bold magenta][Tabulating 
Evaluation Results][/bold magenta]: Found experiments : {exps}") for exp in exps: idm_file = os.path.join(input_dir, f"idm_{exp}.json") fvd_file = os.path.join(input_dir, f"fvd_{exp}.json") with open(idm_file, 'r') as f: idm_data = json.load(f) with open(fvd_file, 'r') as f: fvd_data = json.load(f) fvd_data = fvd_data["mean"] fvd_data.pop("exp_name", None) idm_data = idm_data["metric_mean_on_task"] metrics_entry = { "experiment": exp, } # merge dict metrics_entry.update(fvd_data) metrics_entry.update(idm_data) metrics_list.append(metrics_entry) # Convert list of metrics to a DataFrame df = pd.DataFrame(metrics_list) # Save the DataFrame to a CSV file df.to_csv(output_path, index=False) print(f"[bold red][Tabulating Evaluation Results End][/bold red] Metrics tabulated and saved to {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Tabulate metrics from JSON files") parser.add_argument("--input_dir", type=str, required=True, help="Directory containing JSON metric files") parser.add_argument("--output_path", type=str, required=True, help="Path to save the tabulated metrics CSV file") args = parser.parse_args() tabluate_metrics(args.input_dir, args.output_path) ================================================ FILE: mineworld.py ================================================ import os import sys sys.path.append(os.getcwd()) import gradio as gr from PIL import Image import numpy as np import torch import cv2 from utils import load_model from omegaconf import OmegaConf from argparse import ArgumentParser from collections import deque import tempfile import atexit from torchvision import transforms from einops import rearrange from mcdataset import MCDataset import itertools class Buttons: ATTACK = "attack" BACK = "back" FORWARD = "forward" JUMP = "jump" LEFT = "left" RIGHT = "right" SNEAK = "sneak" SPRINT = "sprint" USE = "use" DROP = "drop" SWAPHANDS = "swapHands" PICKITEM = "pickItem" ALL = [ ATTACK, USE, FORWARD, BACK, LEFT, RIGHT, JUMP, SNEAK, SPRINT, DROP, SWAPHANDS, PICKITEM, # INVENTORY, # ESC, ] + [f"hotbar.{i}" for i in range(1, 10)] KEYBOARD_BUTTON_MAPPING = { "key.keyboard.s" :"back", "key.keyboard.q" :"drop", "key.keyboard.w" :"forward", "key.keyboard.1" :"hotbar.1", "key.keyboard.2" :"hotbar.2", "key.keyboard.3" :"hotbar.3", "key.keyboard.4" :"hotbar.4", "key.keyboard.5" :"hotbar.5", "key.keyboard.6" :"hotbar.6", "key.keyboard.7" :"hotbar.7", "key.keyboard.8" :"hotbar.8", "key.keyboard.9" :"hotbar.9", "key.keyboard.space" :"jump", "key.keyboard.a" :"left", "key.keyboard.d" :"right", "key.keyboard.left.shift" :"sneak", "key.keyboard.left.control" :"sprint", "key.keyboard.f" :"swapHands", } # Template action NOOP_ACTION = { "forward": 0, "back": 0, "left": 0, "right": 0, "jump": 0, "attack": 0, "use": 0, "pickItem": 0, "drop": 0, "sneak": 0, "sprint": 0, "swapHands": 0, "hotbar.1": 0, "hotbar.2": 0, "hotbar.3": 0, "hotbar.4": 0, "hotbar.5": 0, "hotbar.6": 0, "hotbar.7": 0, "hotbar.8": 0, "hotbar.9": 0, "camera": np.array([0, 0]), } ACTION_BUTTON = { "forward": 0, "back": 0, "left": 0, "right": 0, "attack": 0, "sprint": 0, "jump": 0, "use": 0, "drop": 0, "hotbar.1": 0, "pickItem": 0, } FOR_BACK = { "forward": 0, "back": 0, } L_R = { "left": 0, "right": 0, } ATT_USE_DROP = { "attack": 0, "use": 0, "drop": 0, } JUMP_SPR = { "jump": 0, "sprint": 0, } HORBAR = { "hotbar.1": 0, "hotbar.2": 0, "hotbar.3": 0, "hotbar.4": 0, "hotbar.5": 0, "hotbar.6": 0, "hotbar.7": 0, "hotbar.8": 0, "hotbar.9": 0, } safe_globals = {"array": np.array} AGENT_RESOLUTION = 
(384, 224) CAMERA_SCALER = 360.0 / 2400.0 TOKEN_PER_IMAGE = 336 TOKEN_PER_ACTION = 11 VIDEO_FRAMES = [] GENERATED_FILES = [] frame_cache = [] action_cache = [] last_pos = 0 MC_ACTION_MAP = MCDataset() SHOW_FRAMES = 8 REFERENCE_FRAME = None CONTEXT_LEN = None DIAGD = False WINDOWSIZE = 4 def get_args(): parser = ArgumentParser() parser.add_argument('--scene', type=str, default='./assets/scene.mp4') parser.add_argument('--model_ckpt', type=str, default='./checkpoints/700M_16f.pt') parser.add_argument('--config', type=str, default='./configs/700M_16f.yaml') parser.add_argument('--reference_frame', type=int, default=8) parser.add_argument('--diagd', action='store_true', help='use diagd') parser.add_argument('--window_size', type=int, default=4) args = parser.parse_args() return args def make_action_dict(action_line): action_dict = {'ESC': 0, 'back': 0, 'drop': 0, 'forward': 0, 'hotbar.1': 0, 'hotbar.2': 0, 'hotbar.3': 0, 'hotbar.4': 0, 'hotbar.5': 0, 'hotbar.6': 0, 'hotbar.7': 0, 'hotbar.8': 0, 'hotbar.9': 0, 'inventory': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0, 'swapHands': 0, 'camera': np.array([0, 0]), 'attack': 0, 'use': 0, 'pickItem': 0} if isinstance(action_line, str): action_line = action_line.split(",") action_dict['camera'] = np.array((int(action_line[-2]), int(action_line[-1]))) for act in action_line: if act in Buttons.ALL: action_dict[act] = 1 return action_dict def stack_images(imgs): width, height = imgs[0].size new_im = Image.new('RGB', (4*width, height*2)) for i, im in enumerate(imgs): new_im.paste(im, (width*(i%4), height*(i//4))) return new_im def get_action_line(acts): action_lst = [] for k in acts.keys(): if k != "camera" and acts[k] == 1: action_lst.append(k) action_lst.append(str(acts["camera"][0])) action_lst.append(str(acts["camera"][1])) return ",".join(action_lst) def run_prediction(btns_choices, cam_x_input, cam_y_input): global frame_cache, action_cache, actions_show, images_show, VIDEO_FRAMES, last_pos, CONTEXT_LEN, REFERENCE_FRAME assert len(frame_cache) == len(action_cache)+1 if len(action_cache) >= CONTEXT_LEN - 1: for _ in range(CONTEXT_LEN - REFERENCE_FRAME): frame_cache.popleft() action_cache.popleft() model.transformer.refresh_kvcache() _frame_iter = itertools.islice(frame_cache, 0, len(frame_cache)-1) _act_iter = itertools.islice(action_cache, 0, len(action_cache)) _vis_act = [ torch.cat([img, act], dim=1) for img, act in zip(_frame_iter, _act_iter) ] _vis_act.append(frame_cache[-1]) _vis_act = torch.cat(_vis_act, dim=-1) _, last_pos = model.transformer.prefill_for_gradio(_vis_act) act_dict = make_action_dict(btns_choices) act_dict['camera'] = np.array((int(cam_y_input), int(cam_x_input))) ongoing_act = MC_ACTION_MAP.get_action_index_from_actiondict(act_dict, action_vocab_offset=8192) ongoing_act = torch.tensor(ongoing_act).unsqueeze(0).to("cuda") action_cache.append(ongoing_act) with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16): if DIAGD: next_frame, last_pos = model.transformer.diagd_img_token_for_gradio(input_action=ongoing_act, position_id = last_pos, max_new_tokens=TOKEN_PER_IMAGE, windowsize=4) else: next_frame, last_pos = model.transformer.decode_img_token_for_gradio(input_action=ongoing_act, position_id = last_pos, max_new_tokens=TOKEN_PER_IMAGE + 1) # +1 to fill kvcache last_pos = last_pos[0] next_frame = torch.cat(next_frame, dim=-1).to("cuda") frame_cache.append(next_frame) next_frame = tokenizer.token2image(next_frame) next_frame = Image.fromarray(next_frame) if len(images_show) >= SHOW_FRAMES: 
images_show.popleft() actions_show.popleft() btns_choices = btns_choices + [np.array((int(cam_y_input), int(cam_x_input)))] actions_show.append(','.join(str(x) for item in btns_choices for x in (item if isinstance(item, np.ndarray) else [item]))) images_show.append(next_frame) VIDEO_FRAMES.append(next_frame) return next_frame, stack_images(images_show), "        ".join([str(x) for x in actions_show]) def run_prediction_n_times(n, btns_1, btns_2, btns_3, btns_4, btns_5, cam_x_input, cam_y_input): btns_choices = btns_1 + btns_2 + btns_3 + btns_4 + btns_5 if cam_x_input is None: cam_x_input = 0 if cam_y_input is None: cam_y_input = 0 if n is None: n = 1 for i in range(n): yield run_prediction(btns_choices, cam_x_input, cam_y_input) def step_pred_source_video_right(video_path, start): global VIDEO_FRAMES, frame_cache, action_cache, REFERENCE_FRAME, CONTEXT_LEN, REFERENCE_FRAME VIDEO_FRAMES.clear(); frame_cache.clear(); action_cache.clear() if start is None or start < 0 or start > MAX_FRAME: start = 0 return step_video(video_path, start, REFERENCE_FRAME) def on_download_button_click(fps=6): if not VIDEO_FRAMES: print("The frames list is empty.") return temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", dir="/tmp") video_path = temp_file.name temp_file.close() video_writer = cv2.VideoWriter( video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, AGENT_RESOLUTION ) for frame in VIDEO_FRAMES: frame_np = np.array(frame) frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR) video_writer.write(frame_bgr) video_writer.release() GENERATED_FILES.append(video_path) os.chmod(video_path, 0o644) return video_path def cleanup_files(): for video_path in GENERATED_FILES: try: os.remove(video_path) print(f"Deleted file: {video_path}") except OSError as e: print(f"Error deleting file {video_path}: {e}") atexit.register(cleanup_files) def step_video(video_path, start_frame, frame_count): global images_show, actions_show, frame_cache, action_cache, VIDEO_FRAMES, last_pos, CONTEXT_LEN, REFERENCE_FRAME VIDEO_FRAMES = [] images_show = [] actions_show = [] video = cv2.VideoCapture(video_path) json_data = MC_ACTION_MAP.read_jsonl(video_path[:-4]+".jsonl") frames_tensor = [] action_cache = [] for i in range(start_frame, start_frame + frame_count): step_action = json_data[i] step_action, _ = MC_ACTION_MAP.json_action_to_env_action(step_action) actions_show.append(get_action_line(step_action)) act_index = MC_ACTION_MAP.get_action_index_from_actiondict(step_action, action_vocab_offset=8192) act_index = torch.tensor(act_index).unsqueeze(0) action_cache.append(act_index.to("cuda")) video.set(cv2.CAP_PROP_POS_FRAMES, i) ret, frame = video.read() try: if not ret: raise ValueError(f"frame {i} not ret") cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame) frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8) frame = cv2.resize(frame, AGENT_RESOLUTION, interpolation=cv2.INTER_LINEAR) images_show.append(Image.fromarray(frame)) VIDEO_FRAMES.append(Image.fromarray(frame)) frames_tensor.append(torch.from_numpy(frame)) except Exception as e: print(f"Could not read frame from video {video_path}: {e}") video.release() frames_tensor = torch.stack(frames_tensor, dim=0).to("cuda") frames_tensor = frames_tensor.permute(0, 3, 1, 2).float() / 255.0 frames_tensor = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(frames_tensor) with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16): images_token = tokenizer.tokenize_images(frames_tensor) images_token = rearrange(images_token, '(b t) h w 
-> (b t) (h w)', b=1) frame_cache = deque(torch.split(images_token, split_size_or_sections=1, dim=0)) action_cache = deque(action_cache) action_cache.pop() images_show = deque(images_show) actions_show = deque(actions_show) actions_show.pop() model.transformer.refresh_kvcache() _frame_iter = itertools.islice(frame_cache, 0, len(frame_cache)-1) _act_iter = itertools.islice(action_cache, 0, len(action_cache)) _vis_act = [ torch.cat([img, act], dim=1) for img, act in zip(_frame_iter, _act_iter) ] _vis_act.append(frame_cache[-1]) _vis_act = torch.cat(_vis_act, dim=-1) _, last_pos = model.transformer.prefill_for_gradio(_vis_act) while len(images_show) > SHOW_FRAMES: images_show.popleft() actions_show.popleft() # WARNING: why dont pop actions return stack_images(images_show), "        ".join([str(x) for x in actions_show]), None css = """ .custom-tab h2 { font-size: 34px; /* 字体大小 */ font-weight: bold; /* 加粗字体 */ color: #ff6600; /* 字体颜色 */ text-shadow: 1px 1px 2px #000000; /* 文字阴影效果 */ } """ if __name__ == "__main__": args = get_args() if args.diagd: DIAGD = True WINDOWSIZE = args.window_size cap = cv2.VideoCapture(args.scene) global MAX_FRAME MAX_FRAME = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 10 config = OmegaConf.load(args.config) REFERENCE_FRAME = args.reference_frame CONTEXT_LEN = int(config.model.params.transformer_config.params.max_position_embeddings / (TOKEN_PER_ACTION + TOKEN_PER_IMAGE)) assert CONTEXT_LEN > REFERENCE_FRAME model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True) tokenizer = model.tokenizer with gr.Blocks(css=css) as demo: with gr.Tab("MineWorld", elem_classes="custom-tab"): source_video_path = gr.Text(value=args.scene, visible=False) with gr.Row(): source_video_actions = gr.Markdown(visible=False) instruction = gr.Markdown("press 'Jump to start frame' to init or restart the game, you can choose different sences by modifed start frame from 0 to 4100", visible=False) with gr.Row(): source_video_images = gr.Image(width=1280, height=360, label="last 8 frames", show_fullscreen_button = True, every=1) with gr.Row(equal_height=True): with gr.Column(min_width=60): vid_frame_start = gr.Number(step=1, value=0, info="start frame", min_width=20, show_label=False, minimum=0, maximum=MAX_FRAME) # vid_num_frames = gr.Number(step=1, value=4, label="num_frames", min_width=50) # with gr.Column(min_width=60): run_steps = gr.Number(step=1, value=1, info="Repeat same action n times", min_width=20, minimum=1, maximum=8, show_label=False) with gr.Column(min_width=60): btn1 = list(FOR_BACK.keys()) btns_1 = gr.CheckboxGroup(choices=btn1, show_label=False) vid_right_btn = gr.Button(value="Jump to start frame", size='sm') with gr.Column(min_width=60): btn2 = list(L_R.keys()) btns_2 = gr.CheckboxGroup(choices=btn2, show_label=False) predict_run_btn = gr.Button(value="Run", variant="primary", size='sm') with gr.Column(min_width=60): btn4 = list(JUMP_SPR.keys()) btns_4 = gr.CheckboxGroup(choices=btn4, show_label=False) download_game_btn = gr.Button("Generate Video", size='sm') with gr.Column(min_width=60): cam_y_input = gr.Number(step=1, value=0, info="camera Y ⬆️(-),0,(+)⬇️", min_width=20, minimum=-90, maximum=90, show_label=False) cam_x_input = gr.Number(step=1, value=0, info="camera X ⬅️(-),0,(+)➡️", min_width=20, minimum=-90, maximum=90, show_label=False) with gr.Row(): with gr.Column(min_width=250): video_display = gr.Video(label="video", width=384, height=224) with gr.Column(min_width=200): predict_result_imgs = gr.Image(label="last generated frame",width=384, height=224) with 
gr.Column(min_width=200): with gr.Row(): btn3 = list(ATT_USE_DROP.keys()) btns_3 = gr.CheckboxGroup(choices=btn3, show_label=False) with gr.Row(): btn5 = list(HORBAR.keys()) btns_5 = gr.CheckboxGroup(choices=btn5, show_label=False) vid_right_btn.click(fn=step_pred_source_video_right, inputs=[source_video_path, vid_frame_start], outputs=[source_video_images, source_video_actions, predict_result_imgs]) predict_run_btn.click(fn=run_prediction_n_times, inputs=[run_steps, btns_1, btns_2, btns_3, btns_4, btns_5, cam_x_input, cam_y_input], outputs=[predict_result_imgs, source_video_images, source_video_actions],) download_game_btn.click(fn=on_download_button_click, inputs=[], outputs=video_display) demo.load(fn=step_pred_source_video_right, inputs=[source_video_path, gr.Number(value=25, visible=False)], outputs=[source_video_images, source_video_actions, predict_result_imgs]) demo.queue() demo.launch(server_name="0.0.0.0", max_threads=256, server_port=7861, share=True) ================================================ FILE: requirements.txt ================================================ torch==2.6.0 torchvision==0.21.0 omegaconf==2.3.0 transformers==4.48.1 opencv-python==4.11.0.86 attrs==25.3.0 diffusers==0.32.2 gradio==5.24.0 einops==0.8.1 diffusers==0.32.2 scipy==1.15.2 torch-fidelity==0.3.0 gym3==0.3.3 gym==0.26.2 scikit-learn==1.6.1 ================================================ FILE: scripts/compute_metrics.sh ================================================ #!/bin/bash VIDEO_RESULTS_ROOT_DEFAULT="videos" METRICS_ROOT_DEFAULT="metrics_log" JSONL_PATH_DEFAULT="validation/validation" IDM_CKPT_DIR="checkpoints/IDM" VIDEO_RESULTS_ROOT=${1:-$VIDEO_RESULTS_ROOT_DEFAULT} METRICS_ROOT=${2:-$METRICS_ROOT_DEFAULT} JSONL_PATH=${3:-$JSONL_PATH_DEFAULT} echo "VIDEO_RESULTS_ROOT = $VIDEO_RESULTS_ROOT" echo "METRICS_ROOT = $METRICS_ROOT" echo "JSONL_PATH = $JSONL_PATH" # Loop through each subdirectory in VIDEO_RESULTS_ROOT for video_dir1 in "$VIDEO_RESULTS_ROOT"/*/; do # Skip the 'metrics' directory if [ -d "$video_dir1" ] && [ "$(basename "$video_dir1")" != "metrics" ]; then # Construct the output file name based on the video directory name fvd_output_file="$METRICS_ROOT/fvd_$(basename "$video_dir1").json" echo $fvd_output_file # Run the python command for each video directory python metrics/common_metrics.py --video_dir2 $JSONL_PATH --video_length 15 --channel 3 --size "(224,384)" \ --video_dir1 "$video_dir1" --output-file "$fvd_output_file" idm_output_file="$METRICS_ROOT/idm_$(basename "$video_dir1").json" python metrics/IDM/inverse_dynamics_model.py --weights $IDM_CKPT_DIR/"4x_idm.weights" \ --infer-demo-num 1 --n-frames 15 \ --model $IDM_CKPT_DIR/"4x_idm.model" --video-path $video_dir1 \ --output-file "$idm_output_file" \ --jsonl-path $JSONL_PATH fi done python metrics/tabulate_all_results.py --input_dir $METRICS_ROOT --output_path $METRICS_ROOT/latest_metrics.csv ================================================ FILE: scripts/inference_16f_models.sh ================================================ DATA_ROOT="validation/validation" ########################### #### Inference 1200M models ########################### CONFIG="configs/1200M_16f.yaml" CKPT_PATH="checkpoints/1200M_16f.ckpt" OUTPUT_PATH="./videos/1200M_16f200_demo1gen15_naive" python inference.py \ --data_root $DATA_ROOT \ --config $CONFIG \ --model_ckpt $CKPT_PATH \ --demo_num 1 --frames 15 \ --accelerate-algo 'naive' \ --top_p 0.8 \ --output_dir $OUTPUT_PATH ================================================ FILE: scripts/setup_metrics.sh 
================================================ # clone common_metrics_on_video_quality repository git clone git@github.com:CIntellifusion/common_metrics_on_video_quality.git # get IDM weights mkdir -p checkpoints/IDM wget https://openaipublic.blob.core.windows.net/minecraft-rl/idm/4x_idm.model -O checkpoints/IDM/4x_idm.model wget https://openaipublic.blob.core.windows.net/minecraft-rl/idm/4x_idm.weights -O checkpoints/IDM/4x_idm.weights ================================================ FILE: utils.py ================================================ import torch import importlib from rich import print from typing import Union import numpy as np import os def print0(*args, **kwargs): print(*args, **kwargs) # python -m rich.color def tensor_to_uint8(tensor): tensor = torch.clamp(tensor, -1.0, 1.0) tensor = (tensor + 1.0) / 2.0 tensor = (tensor.cpu().numpy() * 255).astype(np.uint8) return tensor def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def instantiate_from_config(config): if not "target" in config: raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def load_model_from_config(config, sd, gpu=True, eval_mode=True): model = instantiate_from_config(config) if sd is not None: missing, unexpected = model.load_state_dict(sd, strict=False) if len(missing) != 0: raise ValueError(f"Missing keys: {missing}") if gpu: model.cuda() if eval_mode: model.eval() return {"model": model} def get_valid_dirs(dir1: str, dir2: Union[None, str] = None, dir3: Union[None, str] = None) -> Union[None, str]: if (dir1 is not None) and os.path.isdir(dir1): return dir1 elif (dir2 is not None) and os.path.isdir(dir2): return dir2 elif (dir3 is not None) and os.path.isdir(dir3): return dir3 else: return None def get_valid_paths(path1: str, path2: Union[None, str] = None, path3: Union[None, str] = None) -> Union[None, str]: if (path1 is not None) and os.path.isfile(path1): return path1 elif (path2 is not None) and os.path.isfile(path2): return path2 elif (path3 is not None) and os.path.isfile(path3): return path3 else: return None def load_model(config, ckpt, gpu, eval_mode): if str(ckpt).endswith(".bin"): weight = torch.load(ckpt) elif ckpt: weight = torch.load(ckpt, map_location="cpu")["state_dict"] model = load_model_from_config(config.model, weight, gpu=gpu, eval_mode=eval_mode)["model"] model.load_state_dict(weight, strict=False) model.to(torch.float16) return model ================================================ FILE: vae.py ================================================ import torch import torch.nn as nn import diffusers from safetensors.torch import load_file as load_safetensors from utils import print0, get_valid_paths, tensor_to_uint8 class VAE(nn.Module): def __init__(self, config_path: str, ckpt_path: str, ): super().__init__() config_path = get_valid_paths(config_path) print0(f"[bold magenta]\[VAE][/bold magenta] Loading VQGAN from {config_path}") self.model = diffusers.VQModel.from_config(config_path) ckpt_path = get_valid_paths(ckpt_path) print0(f"[bold magenta]\[VAE][/bold magenta] Use ckpt_path: {ckpt_path}") self.init_from_ckpt(ckpt_path) def init_from_ckpt( self, path: str ) -> None: if path.endswith("ckpt"): ckpt = torch.load(path, map_location="cpu", weights_only=False) if "state_dict" in ckpt: weights = ckpt["state_dict"] else: weights = 
ckpt elif path.endswith("safetensors"): weights = load_safetensors(path) else: raise NotImplementedError missing, unexpected = self.load_state_dict(weights, strict=False) print0( f"[bold magenta]\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: print0(f"[bold magenta]\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Missing Keys: {missing}") # if len(unexpected) > 0: # print0(f"[bold magenta]\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Unexpected Keys: {unexpected}") @torch.no_grad() def tokenize_images(self, x: torch.Tensor, sane_index_shape: bool = True): h = self.model.encoder(x) h = self.model.quant_conv(h) if sane_index_shape: orig_sane_index_shape = self.model.quantize.sane_index_shape self.model.quantize.sane_index_shape = True z_q, loss, (perplexity, min_encodings, min_encoding_indices) = self.model.quantize(h) if sane_index_shape: self.model.quantize.sane_index_shape = orig_sane_index_shape return min_encoding_indices # yang ye @torch.no_grad() def token2image(self, tokens): assert tokens.max() < 8192, f"code max value is {tokens.max()}" shape = (1, 14, 24, 64) with torch.autocast(device_type='cuda', dtype=torch.float32): quant = self.model.quantize.get_codebook_entry(tokens, shape) quant2 = self.model.post_quant_conv(quant) dec = self.model.decoder(quant2) img = tensor_to_uint8(dec[0]).transpose(1, 2, 0) return img
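# A minimal usage sketch (not part of the original file): round-trip one RGB frame
# through the tokenizer above, i.e. encode it into 14 x 24 = 336 codebook indices
# and decode the indices back to pixels. It assumes a CUDA device, the checkpoint
# layout described in the README ("checkpoints/vae/config.json" and
# "checkpoints/vae/vae.ckpt"), and a hypothetical local image "frame.png";
# adjust these paths to your setup.
if __name__ == "__main__":
    from PIL import Image
    from torchvision import transforms

    tokenizer = VAE("checkpoints/vae/config.json", "checkpoints/vae/vae.ckpt").cuda().eval()

    frame = Image.open("frame.png").convert("RGB").resize((384, 224))  # (width, height)
    x = transforms.ToTensor()(frame)                                   # (3, 224, 384), values in [0, 1]
    x = transforms.Normalize([0.5] * 3, [0.5] * 3)(x)                  # rescale to [-1, 1]
    x = x.unsqueeze(0).cuda()                                          # (1, 3, 224, 384)

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        tokens = tokenizer.tokenize_images(x)   # (1, 14, 24) codebook indices
        tokens = tokens.reshape(1, -1)          # (1, 336) flattened tokens
        recon = tokenizer.token2image(tokens)   # (224, 384, 3) uint8 array

    Image.fromarray(recon).save("frame_recon.png")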