Full Code of microsoft/mineworld for AI

Repository: microsoft/mineworld
Branch: main
Commit: 9f49efbcc68a
Files: 40
Total size: 243.3 KB

Directory structure:
gitextract_mrwf1vu0/

├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── configs/
│   ├── 1200M_16f.yaml
│   ├── 1200M_32f.yaml
│   ├── 300M_16f.yaml
│   ├── 700M_16f.yaml
│   └── 700M_32f.yaml
├── diagonal_decoding.py
├── inference.py
├── lvm.py
├── mcdataset.py
├── metrics/
│   ├── IDM/
│   │   ├── inverse_dynamics_model.py
│   │   └── lib/
│   │       ├── __init__.py
│   │       ├── action_head.py
│   │       ├── action_mapping.py
│   │       ├── actions.py
│   │       ├── impala_cnn.py
│   │       ├── masked_attention.py
│   │       ├── minecraft_util.py
│   │       ├── misc.py
│   │       ├── mlp.py
│   │       ├── normalize_ewma.py
│   │       ├── policy.py
│   │       ├── scaled_mse_head.py
│   │       ├── torch_util.py
│   │       ├── tree_util.py
│   │       ├── util.py
│   │       └── xf.py
│   ├── common_metrics.py
│   └── tabulate_all_results.py
├── mineworld.py
├── requirements.txt
├── scripts/
│   ├── compute_metrics.sh
│   ├── inference_16f_models.sh
│   └── setup_metrics.sh
├── utils.py
└── vae.py

================================================
FILE CONTENTS
================================================

================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns


================================================
FILE: LICENSE
================================================
    MIT License

    Copyright (c) Microsoft Corporation.

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.


================================================
FILE: README.md
================================================
<div align="center">

# MineWorld <br> <sub>A Real-time Interactive World Model on Minecraft</sub>

[![arXiv](https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2504.08388) &ensp; [![Project](https://img.shields.io/badge/Project-Page-blue?logo=homepage&logoColor=white)](https://aka.ms/mineworld) &ensp; [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/microsoft/mineworld)
</div>

We introduce MineWorld, an interactive world model on Minecraft that brings several key advancements over existing approaches: 
* 🕹️ **High generation quality**. Built on a visual-action autoregressive Transformer, MineWorld generates coherent, high-fidelity frames conditioned on both visuals and actions. 
* 🕹️ **Strong controllability**. We propose benchmarks for the action-following capacity, where MineWorld shows precise and consistent behavior. 
* 🕹️ **Fast inference speed**. With Diagonal Decoding, MineWorld achieves a generation rate of 4 to 7 frames per second, enabling real-time interaction in open-ended game environments. 

https://github.com/user-attachments/assets/2f5b4740-badd-453c-970d-061abd367f82

## 🔥 News
* May, 2025: The model checkpoints in the [Huggingface repo](https://huggingface.co/microsoft/mineworld) have been temporarily taken down.
* April, 2025: 🚀 [MineWorld](https://github.com/microsoft/mineworld) was released!
* March, 2025: 🚀 The paper of [Diagonal Decoding](https://arxiv.org/pdf/2503.14070) was released!

## 🔧 Setup
1. Clone this repository and navigate to MineWorld folder:
```bash
git clone https://github.com/microsoft/mineworld.git
cd mineworld
```
2. We provide a `requirements.txt` file for setting up a pip environment.
```bash
# 1. Prepare conda environment
conda create -n mineworld python=3.10
# 2. Activate the environment
conda activate mineworld
# 3. Install the requirements
pip3 install -r requirements.txt
```

We recommend using a high-end GPU for inference. All testing and development was done on A100 and H100 GPUs. 


## 🎈 Checkpoints
Download pre-trained models [here](https://huggingface.co/microsoft/mineworld). Each checkpoint has a corresponding config file with the same name in the `configs` folder in this repository. All models share the same VAE checkpoint and config. The directory structure is as follows:
```
└── checkpoints
    ├── 300M_16f.ckpt
    ├── 700M_16f.ckpt
    ├── 700M_32f.ckpt
    ├── 1200M_16f.ckpt
    └── 1200M_32f.ckpt
    └── vae
        ├── config.json
        └── vae.ckpt
└── validation
    └── validation.zip
└── gradio_scene
    ├── scene.mp4
    └── scene.jsonl
```

## 🚀 Inference
We provide two ways to use our model: interacting with it in a web demo, and running it locally to reproduce the evaluation results in our paper. In addition to downloading the checkpoints and placing them in the `checkpoints` folder, you also need to download `scene.mp4` and `scene.jsonl` to run the web demo. Make sure they are placed in the same directory.

### Run Web Demo

To launch the webpage game, run the following command:
```bash
python mineworld.py --scene "path/to/scene.mp4" \
    --model_ckpt "path/to/ckpt" \
    --config "path/to/config"
```

![image](assets/demo.png)

Once the demo is running, you can access the website through the local URL or the public URL displayed in the command line. Initialization and the first action may take some time due to compilation.

You can specify a reference frame using the `--reference_frame` option; its value should be larger than `4` and smaller than the context length of the model (i.e., `16` or `32`, depending on the model used). A higher reference frame number generally corresponds to better visual quality. Once the initial state has been set, perform game actions by selecting options in each chatbox.
The game progresses when you press the "Run" button, which displays the last `8` frames and the most recent frame separately. Players can also set an action count to repeat an action multiple times.
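
For reference, the `16` and `32` frame context lengths line up with the per-frame token counts in the configs (`configs/300M_16f.yaml` lists `image_num: 336` and `token_num: 347`, i.e. 336 image tokens plus 11 action tokens per frame). A quick sanity check:

```python
# Per-frame token budget, taken from configs/300M_16f.yaml
# (image_num: 336, token_num: 347 => 11 action tokens per frame).
tokens_per_frame = 336 + 11
assert 16 * tokens_per_frame == 5552   # max_position_embeddings in the *_16f configs
assert 32 * tokens_per_frame == 11104  # max_position_embeddings in the *_32f configs
```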

Explanations to the buttons in the web demo are as follows:
```
Start frame: select a frame in scene.mp4 with its frame index
Jump to start frame: use the selected frame as the initial state
Camera `X` and `Y`: control the camera movements between `-90` and `90` degrees
Other action buttons: same as the actions in Minecraft 
Generate video: save previous game progress
```

### Run Local Inference

To run inference locally, use the following command:

```bash
python inference.py \
        --data_root "/path/to/validation/dataset" \
        --model_ckpt "path/to/ckpt" \
        --config "path/to/config" \
        --demo_num 1 \
        --frames 15 \
        --accelerate-algo 'naive' \
        --top_p 0.8 \
        --output_dir "path/to/output"
```

Check `scripts/inference_16f_models.sh` for examples. To switch between naive autoregressive decoding and diagonal decoding, set the `--accelerate-algo` option to `naive` or `image_diagd` respectively. 

After running inference on a set of videos, you can compute the metrics and reproduce the numerical results in our paper with the following scripts:
```bash 
bash scripts/setup_metrics.sh # only required the first time
bash scripts/compute_metrics.sh 
```

The evaluation outputs will have the following structure: 
```
└── videos 
    ├── inference_setting1
        ├── clip_1.mp4
        └── clip_1.json
    ├── inference_setting2
        ├── clip_1.mp4
        └── clip_1.json
└── metrics_log
    ├── fvd_inference_setting1.json
    ├── fvd_inference_setting2.json
    ├── idm_inference_setting1.json
    ├── idm_inference_setting2.json
    └── latest_metrics.csv 
```
All results will be aggregated into `metrics_log/latest_metrics.csv`.
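
If you want to post-process the aggregated results programmatically, a minimal sketch (assuming `latest_metrics.csv` is a plain comma-separated table with a header row; the column names shown are hypothetical):

```python
import csv

def load_metrics(path):
    """Read an aggregated metrics CSV into a list of row dicts."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))

# Example usage (column names are illustrative, not guaranteed by the scripts):
# rows = load_metrics("metrics_log/latest_metrics.csv")
# fvd_by_setting = {r["setting"]: float(r["fvd"]) for r in rows}
```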

## 💡 Intended Uses

Our model is trained solely in the Minecraft game domain. As a world model, it is given an initial image of the game scene, and the user selects an action from the action list. The model then generates the next scene resulting from the selected action.


## 🪧 Out-of-scope Uses

Our models are not specifically designed for any tasks or scenarios other than the Minecraft domain. 

Developers should expect failures in generation results for out-of-scope scenarios. 

Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case, and evaluate and mitigate for privacy, safety, and fairness before using within a specific downstream use case, particularly for high-risk scenarios.

## 🤖️ Risks and Limitations 

Some of the limitations of this model to be aware of include: 
* Quality of Service: MineWorld is trained solely on Minecraft, so it cannot generate results for other video domains (such as general internet video), nor can it generate videos at higher resolutions.
* Information Reliability: MineWorld is trained on videos at a fixed, relatively low resolution, so the generated results may lose fine-grained detail. 
* MineWorld inherits any biases, errors, or omissions characteristic of its training data, which may be amplified by any AI-generated interpretations. 
* MineWorld was developed for research and experimental purposes. Further testing and validation are needed before considering its application in commercial or real-world scenarios. 
* Providing input images from outside Minecraft will result in incoherent imagery and should not be attempted.
* Users are responsible for sourcing their datasets legally and ethically. This could include securing appropriate copyrights, ensuring consent for use of audio/images, and/or the anonymization of data prior to use in research.

## ✏️ BibTeX

```bibtex
@article{guo2025mineworld,
  title={MineWorld: a Real-Time and Open-Source Interactive World Model on Minecraft}, 
  author={Guo, Junliang and Ye, Yang and He, Tianyu and Wu, Haoyu and Jiang, Yushu and Pearce, Tim and Bian, Jiang},
  year={2025},
  journal={arXiv preprint arXiv:2504.08388},
}
```

## 🤗 Acknowledgments
This codebase borrows code from [VPT](https://github.com/openai/Video-Pre-Training) and [generative-models](https://github.com/Stability-AI/generative-models). We thank them for their efforts and innovations, which have made the development process more efficient and convenient.

Thank you to everyone who contributed their wisdom and efforts to this project.

## ☎️ Contact

We welcome feedback and collaboration from our audience. If you have suggestions, questions, or observe unexpected/offensive behavior in our technology, please contact us through `tianyuhe AT microsoft.com`.

## 📄 Contributing

This project welcomes contributions and suggestions.  Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.


## 📍 Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 
trademarks or logos is subject to and must follow 
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.


================================================
FILE: SECURITY.md
================================================
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

  * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
  * Full paths of source file(s) related to the manifestation of the issue
  * The location of the affected source code (tag/branch/commit or direct URL)
  * Any special configuration required to reproduce the issue
  * Step-by-step instructions to reproduce the issue
  * Proof-of-concept or exploit code (if possible)
  * Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->


================================================
FILE: SUPPORT.md
================================================
# TODO: The maintainer of this repo has not yet edited this file

**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?

- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.

*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*

# Support

## How to file issues and get help  

This project uses GitHub Issues to track bugs and feature requests. Please search the existing 
issues before filing new issues to avoid duplicates.  For new issues, file your bug or 
feature request as a new Issue.

For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.

## Microsoft Support Policy  

Support for this **PROJECT or PRODUCT** is limited to the resources listed above.


================================================
FILE: configs/1200M_16f.yaml
================================================
model:
  target: lvm.LlamaLVM
  params:
    model_class: lvm.LlamaForCausalLM
    tokenizer_config:
      target: vae.VAE
      params:
        config_path: checkpoints/vae/config.json
        ckpt_path: checkpoints/vae/vae.ckpt        
    transformer_config:
      target: transformers.LlamaConfig
      params:
        max_position_embeddings: 5552
        hidden_size: 2048
        intermediate_size: 8192
        num_attention_heads: 32
        num_key_value_heads: 4
        num_hidden_layers: 20
        rope_theta: 10000.0
        torch_dtype: bfloat16
        rms_norm_eps: 1.0e-05
        vocab_size: 8262
        attention_bias: false
        mlp_bias: false

================================================
FILE: configs/1200M_32f.yaml
================================================
model:
  target: lvm.LlamaLVM
  params:
    model_class: lvm.LlamaForCausalLM
    tokenizer_config:
      target: vae.VAE
      params:
        config_path: checkpoints/vae/config.json
        ckpt_path: checkpoints/vae/vae.ckpt   
    transformer_config:
      target: transformers.LlamaConfig
      params:
        max_position_embeddings: 11104
        hidden_size: 2048
        intermediate_size: 8192
        num_attention_heads: 32
        num_key_value_heads: 4
        num_hidden_layers: 20
        rope_theta: 10000.0
        torch_dtype: bfloat16
        rms_norm_eps: 1.0e-05
        vocab_size: 8262
        attention_bias: false
        mlp_bias: false

================================================
FILE: configs/300M_16f.yaml
================================================
model:
  target: lvm.LlamaLVM
  params:
    model_class: lvm.LlamaForCausalLM
    tokenizer_config:
      target: vae.VAE
      params:
        config_path: checkpoints/vae/config.json
        ckpt_path: checkpoints/vae/vae.ckpt     
    transformer_config:
      target: transformers.LlamaConfig
      params:
        max_position_embeddings: 5552
        hidden_size: 1024
        intermediate_size: 4096
        num_attention_heads: 16
        num_key_value_heads: 4
        num_hidden_layers: 20
        initializer_range: 0.02
        rope_theta: 10000.0
        torch_dtype: bfloat16
        rms_norm_eps: 1.0e-05
        vocab_size: 8262
        attention_bias: false
        mlp_bias: false
        token_num: 347
        image_num: 336
        frame: 16

================================================
FILE: configs/700M_16f.yaml
================================================
model:
  target: lvm.LlamaLVM
  params:
    model_class: lvm.LlamaForCausalLM
    tokenizer_config:
      target: vae.VAE
      params:
        config_path: checkpoints/vae/config.json
        ckpt_path: checkpoints/vae/vae.ckpt       
    transformer_config:
      target: transformers.LlamaConfig
      params:
        max_position_embeddings: 5552
        hidden_size: 2048
        intermediate_size: 4096
        num_attention_heads: 32
        num_key_value_heads: 4
        num_hidden_layers: 20
        rope_theta: 10000.0
        torch_dtype: bfloat16
        rms_norm_eps: 1.0e-05
        vocab_size: 8262
        attention_bias: false
        mlp_bias: false

================================================
FILE: configs/700M_32f.yaml
================================================
model:
  target: lvm.LlamaLVM
  params:
    model_class: lvm.LlamaForCausalLM
    tokenizer_config:
      target: vae.VAE
      params:
        config_path: checkpoints/vae/config.json
        ckpt_path: checkpoints/vae/vae.ckpt       
    transformer_config:
      target: transformers.LlamaConfig
      params:
        max_position_embeddings: 11104
        hidden_size: 2048
        intermediate_size: 4096
        num_attention_heads: 32
        num_key_value_heads: 4
        num_hidden_layers: 20
        rope_theta: 10000.0
        torch_dtype: bfloat16
        rms_norm_eps: 1.0e-05
        vocab_size: 8262
        attention_bias: false
        mlp_bias: false



================================================
FILE: diagonal_decoding.py
================================================
import torch
from typing import Optional
from torch.nn.attention import SDPBackend

def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192):
    """
    Sample from the logits using top-k sampling.
    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
    """
    # logits: [batch_size, seq_len, vocab_size]
    if temperature == 0.0:
        idx_next = torch.argmax(logits[:, -1, :vocab_size], dim=-1, keepdim=True)
    else:
        probs = logits_to_probs(logits[:, -1, :vocab_size], temperature, top_k)
        idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next

def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
    """
    Multinomial sampling without a cuda synchronization.
    Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
    """
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype)

def logits_to_probs(
    logits,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample_top_p(logits, temperature, top_p, vocab_size=8192):
    probs = torch.softmax(logits[:, -1, :vocab_size] / temperature, dim=-1)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > top_p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
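
# Dependency-free illustration (not called by the decoder) of the nucleus
# (top-p) filtering rule used in `sample_top_p` above: keep tokens in order of
# decreasing probability until the cumulative mass *before* the current token
# exceeds top_p, then renormalize the survivors. Plain Python for readability.
def _nucleus_filter(probs, top_p):
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = [], 0.0
    for i in order:
        if cum > top_p:  # mirrors the mask `probs_sum - probs_sort > top_p`
            break
        kept.append(i)
        cum += probs[i]
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}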

def sample_n_top_p(logits, temperature, top_p, vocab_size=8192):
    probs = torch.softmax(logits[:, :, :vocab_size] / temperature, dim=-1)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > top_p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

def sample_n_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192):
    if temperature == 0.0:
        # Greedy: take the argmax independently for each of the n positions
        idx_next = torch.argmax(logits[:, :, :vocab_size], dim=-1, keepdim=True)
    else:
        probs = logits_to_n_probs(logits[:, :, :vocab_size], temperature, top_k)
        idx_next = multinomial_sample_one_no_sync(probs)

    return idx_next

def logits_to_n_probs(
    logits,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def decode_one_token(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
):
    """
    Decode a single token from the autoregressive model.
    """
    logits = model(input_ids=input_ids, position_ids=position_ids)
    if top_p is not None:
        return sample_top_p(logits, temperature=temperature, top_p=top_p)
    else:
        return sample_top_k(logits, temperature=temperature, top_k=top_k)
    
def decode_some_token(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
):
    """
    Decode multiple tokens from the autoregressive model.
    """
    logits = model(input_ids=input_ids, position_ids=position_ids)
    if top_p is not None:
        return sample_n_top_p(logits, temperature=temperature, top_p=top_p)
    else:
        return sample_n_top_k(logits, temperature=temperature, top_k=top_k)
    
def decode_n_tokens(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    num_generate_tokens: int,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.8,
    top_k: Optional[int] = None,
    decode_one_token_function=decode_one_token,
    pixnum: int = 336,
    actnum: int = 11,
    **kwargs,
):
    """
    Decode n tokens from the autoregressive model.
    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
    """
    new_tokens = [input_ids]
    pos_ = position_ids
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    for t in range(num_generate_tokens):
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token_function(
                model,
                input_ids=input_ids,
                position_ids=position_ids,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            pos_ += 1
            position_ids = pos_
            new_tokens.append(next_token.clone())
            input_ids = next_token.clone()

            if (pos_ - pixnum + 1) % (actnum + pixnum) == 0 and t+2 < num_generate_tokens:
                action = kwargs["action"][ (t+2) // pixnum ]
                input_ids = torch.cat((input_ids, action), dim=-1)
                position_ids = torch.tensor([pos_ + _ for _ in range(actnum+1)], dtype=torch.long, device="cuda")
                pos_ += actnum

    return new_tokens
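
# Sanity illustration (not called by the decoder): with the default layout of
# pixnum=336 image tokens followed by actnum=11 action tokens per frame, the
# boundary test `(pos - pixnum + 1) % (actnum + pixnum) == 0` used above fires
# exactly at the last image token of each frame.
def _is_frame_boundary(pos, pixnum: int = 336, actnum: int = 11):
    return (pos - pixnum + 1) % (actnum + pixnum) == 0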

def decode_n_tokens_for_gradio(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    num_generate_tokens: int,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.8,
    top_k: Optional[int] = None,
    decode_one_token_function=decode_one_token,
):
    """
    Decode n tokens from the autoregressive model.
    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
    """
    new_tokens = []
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
    position_id = position_ids[-1].unsqueeze(0)
    assert num_generate_tokens % 336 == 1, "should be pixnum x n + 1 to fill kvcache"
    for t in range(num_generate_tokens):
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token_function(
                model,
                input_ids=input_ids,
                position_ids=position_ids,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            position_id += 1
            position_ids = position_id
            new_tokens.append(next_token.clone())
            input_ids = next_token.clone()
    return new_tokens[:-1], position_id

def prefill(
    model,
    input_ids: torch.Tensor = None,
    position_ids: torch.Tensor = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = 0.8,
    **kwargs,
):
    logits = model(input_ids=input_ids, position_ids=position_ids)
    # Only top-p or top-k can be provided
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
    if top_p is not None:
        return sample_top_p(logits, temperature=temperature, top_p=top_p)
    else:
        return sample_top_k(logits, temperature=temperature, top_k=top_k)

def img_diagd_prepare_inputs(
    ongoing_row_list,
    row_token_num,
    ongoing_input,
    prompt,
    imagenum,
    pixnum: int = 336,
    actnum: int = 11,
    columnnum: int = 24,
    promptlen: int = 347,
    **kwargs
):
    position_ids = []
    
    for i in ongoing_row_list:
        global_idx = promptlen + i * columnnum + row_token_num[i] - 1 + (imagenum - 1) * (pixnum + actnum)
        position_ids.append(global_idx)

    if row_token_num[ongoing_row_list[-1]] == 0:
        append_policy = kwargs.get("append_policy", True)
        if append_policy:
            idx_in_input_ids = ongoing_row_list[-1] * columnnum - 1
            ongoing_input.append(prompt[:, idx_in_input_ids].unsqueeze(-1))
        else:
            ongoing_input.append(ongoing_input[-1])

    input_ids = torch.cat(ongoing_input, dim=1)
    position_ids = torch.tensor(position_ids, device="cuda")

    return input_ids, position_ids
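
# Worked example of the global-position formula above (pure Python, for
# illustration only): the prompt occupies `promptlen` positions, each later
# frame occupies pixnum + actnum positions, and within a frame, row `row`
# starts at column offset row * columnnum.
def _diag_position(row, tokens_in_row, imagenum,
                   promptlen: int = 347, columnnum: int = 24,
                   pixnum: int = 336, actnum: int = 11):
    return (promptlen + row * columnnum + tokens_in_row - 1
            + (imagenum - 1) * (pixnum + actnum))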

def img_diagd_decode_n_tokens(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    num_generate_tokens: int,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.8,
    top_k: Optional[int] = None,
    decode_some_token_function=decode_some_token,
    pixnum: int = 336,
    actnum: int = 11,
    columnnum: int = 24,
    rownum: int = 14,
    windowsize: int = 2,
    promptlen: int = 347,
    **kwargs,
):
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    imagenum = 1
    cur_len = 1
    num_generate_tokens += 1
    prompt = kwargs.pop("prompt", None) 
    new_tokens = [input_ids.clone()]
    row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
    row_token_num[0] += 1 
    ongoing_row_list = [0]
    ongoing_input = [input_ids.clone()]

    while True:
        if cur_len >= num_generate_tokens:
            break

        if cur_len % pixnum == 0:
            imagenum += 1
            action = kwargs["action"][cur_len // pixnum]
            ongoing_input.append(action)
            input_id = torch.cat(ongoing_input, dim=-1)
            position_ids = torch.arange(imagenum * (pixnum+actnum) - actnum - 1, imagenum * (pixnum+actnum), device="cuda")

        image_token_num = cur_len % pixnum

        if image_token_num == 1 and row_token_num[0] == windowsize:
            ongoing_row_list.append(1)

        if image_token_num >= 1:
            input_id, position_ids = img_diagd_prepare_inputs(ongoing_row_list=ongoing_row_list, ongoing_input=ongoing_input, imagenum=imagenum, row_token_num=row_token_num, promptlen=promptlen, prompt=prompt, **kwargs)
            
        num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_some_token_function(
                model,
                input_ids=input_id,
                position_ids=position_ids,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            ongoing_input = []
        if len(ongoing_row_list) == 0:
            cur_len += 1
            ongoing_input.append(next_token[:,-1].clone())
            new_tokens.append(next_token[:,-1].clone())
            ongoing_row_list.append(0)
            row_token_num[0] += 1 
        else:
            need_remove_row = None
            cur_len += num_new_tokens
            for i in range(num_new_tokens):
                position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0) + (imagenum - 1) * pixnum 
                new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())
                ongoing_input.append(next_token[:,i].clone())
                row_token_num[ongoing_row_list[i]] += 1

                if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1:
                    ongoing_row_list.append(ongoing_row_list[i]+1)

                elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum:
                    row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
                    ongoing_row_list = []
                    ongoing_input = [next_token[:,i]]
                    need_remove_row = None
                    break

                if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done
                    ongoing_input.pop()
                    need_remove_row = ongoing_row_list[i]

            if need_remove_row is not None:
                ongoing_row_list.remove(need_remove_row)
    return new_tokens
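
# Illustrative sketch (added commentary, not part of the original code): the
# scheduler above decodes a 14 x 24 token grid as a diagonal wavefront -- row
# r+1 becomes eligible once row r has emitted `windowsize` tokens, so the
# per-frame step count drops from rownum * columnnum sequential steps to
# roughly (rownum - 1) * windowsize + columnnum. The helper below
# (hypothetical, for intuition only) counts those steps.
def _diagonal_schedule_steps(rownum=14, columnnum=24, windowsize=2):
    row_tokens = [0] * rownum
    active = [0]  # rows currently emitting one token per step
    steps = 0
    while any(n < columnnum for n in row_tokens):
        steps += 1
        for r in list(active):
            row_tokens[r] += 1
            # a row unlocks its successor after `windowsize` tokens
            if row_tokens[r] == windowsize and r + 1 < rownum:
                active.append(r + 1)
            if row_tokens[r] == columnnum:  # row finished
                active.remove(r)
    return steps
# e.g. _diagonal_schedule_steps() gives 50 steps instead of 336.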

def img_diagd_prepare_inputs_for_gradio(
    ongoing_row_list,
    row_token_num,
    ongoing_input,
    pixnum: int = 336,
    actnum: int = 11,
    columnnum: int = 24,
    promptlen: int = 347,
):
    position_ids = []
    
    for i in ongoing_row_list:
        global_idx = promptlen + i * columnnum + row_token_num[i] - 1
        position_ids.append(global_idx)

    if row_token_num[ongoing_row_list[-1]] == 0:
        ongoing_input.append(ongoing_input[-1])

    input_ids = torch.cat(ongoing_input, dim=1)
    position_ids = torch.tensor(position_ids, device="cuda")

    return input_ids, position_ids

def img_diagd_decode_n_token_for_gradio(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    num_generate_tokens: int,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.8,
    top_k: Optional[int] = None,
    decode_some_token_function=decode_some_token,
    pixnum: int = 336,
    columnnum: int = 24,
    rownum: int = 14,
    windowsize: int = 2,
):
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    cur_len = 0
    promptlen = position_ids[-1] + 1
    
    new_tokens = []
    row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
    ongoing_row_list = []
    ongoing_input = []
    
    while True:
        if cur_len == num_generate_tokens:
            break

        image_token_num = cur_len

        if image_token_num == 1 and row_token_num[0] == windowsize:
            ongoing_row_list.append(1)
        if image_token_num == 0:
            input_id = input_ids

        if image_token_num >=1:
            input_id, position_ids = img_diagd_prepare_inputs_for_gradio(ongoing_row_list=ongoing_row_list, ongoing_input=ongoing_input, row_token_num=row_token_num, promptlen=promptlen)
            
        num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1
        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_some_token_function(
                model,
                input_ids=input_id,
                position_ids=position_ids,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            ongoing_input = []
        if len(ongoing_row_list) == 0:
            cur_len += 1
            ongoing_input.append(next_token[:,-1].clone())
            new_tokens.append(next_token[:,-1].clone())
            ongoing_row_list.append(0)
            row_token_num[0] += 1 
        else:
            need_remove_row = None
            cur_len += num_new_tokens
            for i in range(num_new_tokens):
                position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0)
                new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())
                ongoing_input.append(next_token[:,i].clone())
                row_token_num[ongoing_row_list[i]] += 1

                if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1:
                    ongoing_row_list.append(ongoing_row_list[i]+1)

                elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum:
                    row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
                    ongoing_row_list = []
                    ongoing_input = [next_token[:,i]]
                    need_remove_row = None
                    break

                if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done
                    ongoing_input.pop()
                    need_remove_row = ongoing_row_list[i]

            if need_remove_row is not None:
                ongoing_row_list.remove(need_remove_row)

    with torch.nn.attention.sdpa_kernel(
        SDPBackend.MATH
    ):  # Actually better for Inductor to codegen attention here
        _ = decode_some_token_function(
            model,
            input_ids=next_token[:,-1],
            position_ids=position_ids+1,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    
    return new_tokens, position_ids+2



def vid_diagd_prepare_inputs(
    ongoing_row_list_v,
    row_token_num_v,
    ongoing_input_v,
    prompt,
    pixnum: int = 336,
    actnum: int = 11,
    rownum: int = 14,
    columnnum: int = 24,
    promptlen: int = 347,
    **kwargs
):
    new_frame = False
    position_ids = []

    for i in ongoing_row_list_v:
        global_idx = promptlen + i * columnnum + row_token_num_v[i // rownum][i % rownum] -1 + (i // rownum) * actnum
        position_ids.append(global_idx)

    lastrow = ongoing_row_list_v[-1]
    if lastrow % rownum == 0 and row_token_num_v[lastrow // rownum][lastrow % rownum] == 0:
        # WARNING
        action = kwargs["action"][lastrow // rownum]
        ongoing_input_v.append(action)
        position_ids.pop()
        pos_act = torch.arange( promptlen + (lastrow // rownum) * (pixnum+actnum) - actnum, promptlen + (lastrow // rownum) * (pixnum+actnum), device="cuda")
        position_ids.extend(pos_act.unbind())
        new_frame = True
    elif row_token_num_v[lastrow // rownum][lastrow % rownum] == 0:
        append_policy = kwargs.get("append_policy", True)
        if append_policy:
            idx_in_input_ids = (lastrow % rownum) * columnnum - 1
            ongoing_input_v.append(prompt[:, idx_in_input_ids].unsqueeze(-1))
        else:
            ongoing_input_v.append(ongoing_input_v[-1])

    input_ids = torch.cat(ongoing_input_v, dim=1)
    position_ids = torch.tensor(position_ids, device="cuda")

    return input_ids, position_ids, new_frame

def video_diagd_decode_n_tokens(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    num_generate_tokens: int,
    temperature: float = 1.0,
    top_p: Optional[float] = 0.8,
    top_k: Optional[int] = None,
    decode_some_token_function=decode_some_token,
    pixnum: int = 336,
    actnum: int = 11,
    columnnum: int = 24,
    rownum: int = 14,
    windowsize: int = 2,
    promptlen: int = 347,
    **kwargs,
):
    assert (
        top_p is None or top_k is None
    ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"

    cur_len = 1
    num_generate_tokens += 1
    prompt = kwargs.pop("prompt", None) 
    new_tokens = [input_ids.clone()]
    row_token_num_v = []
    ongoing_row_list_v = [0]
    row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device="cuda"))
    row_token_num_v[0][0] += 1
    if row_token_num_v[0][0] == windowsize:
        ongoing_row_list_v.append(1)

    ongoing_input_v = [input_ids.clone()]

    while True:
        if cur_len >= num_generate_tokens:
            break


        input_id, position_ids, new_frame = vid_diagd_prepare_inputs(ongoing_row_list_v=ongoing_row_list_v, ongoing_input_v=ongoing_input_v, row_token_num_v=row_token_num_v, promptlen=promptlen, prompt=prompt, **kwargs)
            
        num_new_tokens = input_id.shape[1]

        with torch.nn.attention.sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_some_token_function(
                model,
                input_ids=input_id,
                position_ids=position_ids,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            ongoing_input_v = []
            if new_frame:
                next_token = torch.cat([next_token[:,:-actnum], next_token[:,-1:]], dim=1)
                num_new_tokens = num_new_tokens - actnum + 1
            
        need_remove_row = None

        cur_len += num_new_tokens
        for i in range(num_new_tokens):
            last_frame = torch.stack(row_token_num_v[:ongoing_row_list_v[i] // rownum]).sum() if ongoing_row_list_v[i] // rownum > 0 else torch.tensor(0, dtype=torch.long, device="cuda")
            position_in_new_tokens = last_frame + torch.sum(row_token_num_v[ongoing_row_list_v[i] // rownum][:(ongoing_row_list_v[i] % rownum + 1)], dim=0)
                    
            new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())
            ongoing_input_v.append(next_token[:,i].clone())
            row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] += 1

            # WARNING
            if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == windowsize and ongoing_row_list_v[i] < rownum * (num_generate_tokens//pixnum) - 1:
                ongoing_row_list_v.append(ongoing_row_list_v[i]+1)
                if ongoing_row_list_v[-1] % rownum == 0:
                    row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device="cuda"))
            if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == columnnum:
                ongoing_input_v.pop()
                need_remove_row = ongoing_row_list_v[i]

        if need_remove_row is not None:
            ongoing_row_list_v.remove(need_remove_row)
    return new_tokens



================================================
FILE: inference.py
================================================
import os
import cv2
import torch
import time
import numpy as np
from tqdm import tqdm
from rich import print
from PIL import Image
from pathlib import Path
from torch import autocast
from einops import rearrange
from mcdataset import MCDataset
from omegaconf import OmegaConf
from torchvision import transforms
from argparse import ArgumentParser
from utils import load_model, tensor_to_uint8
torch.backends.cuda.matmul.allow_tf32 = False

ACCELERATE_ALGO = [
    'naive', 'image_diagd'
]

TARGET_SIZE=(224,384)
TOKEN_PER_IMAGE = 347 # IMAGE = PIX (336) + ACTION (11)
TOKEN_PER_PIX = 336

safe_globals = {"array": np.array}


def token2video(code_list, tokenizer, save_path, fps, device='cuda'):
    """
    Decode a flat list of image tokens into an .mp4 video.

    No path processing is performed inside this function, to keep it reusable.
    save_path: str, path to save the video; expected to end with .mp4
    """
    if len(code_list) % TOKEN_PER_PIX != 0:
        print(f"code_list length {len(code_list)} is not a multiple of {TOKEN_PER_PIX}")
        return
    num_images = len(code_list) // TOKEN_PER_PIX
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(save_path, fourcc, fps, (384, 224))
    for i in range(num_images):
        code = code_list[i*TOKEN_PER_PIX:(i+1)*TOKEN_PER_PIX]
        code = torch.tensor([int(x) for x in code], dtype=torch.long).to(device)
        img = tokenizer.token2image(code) # pixel
        frame = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        video.write(frame)
    video.release()
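
# Hypothetical helper (not in the repo) making the chunking above explicit:
# each frame is TOKEN_PER_PIX = 336 tokens (a 14 x 24 latent grid for a
# 224 x 384 frame at 16x spatial downsampling), so a flat code list splits
# into consecutive 336-token chunks, one per video frame.
def split_into_frames(code_list, token_per_pix=336):
    if len(code_list) % token_per_pix != 0:
        raise ValueError("code list length must be a multiple of token_per_pix")
    return [code_list[i:i + token_per_pix]
            for i in range(0, len(code_list), token_per_pix)]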

def get_args():
    parser = ArgumentParser()
    parser.add_argument('--data_root', type=str, required=True)
    parser.add_argument('--model_ckpt', type=str, required=True)
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--demo_num', type=int, default=1)
    parser.add_argument('--frames', type=int, required=True)
    parser.add_argument('--window_size', type=int, default=2)
    parser.add_argument('--accelerate-algo', type=str, default='naive', help=f"Accelerate Algorithm Option: {ACCELERATE_ALGO}")
    parser.add_argument('--fps', type=int, default=6)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--top_k', type=int, help='Use top-k sampling')
    group.add_argument('--top_p', type=float, help='Use top-p (nucleus) sampling')
    parser.add_argument('--val_data_num', type=int, default=500, help="number of validation data")
    args = parser.parse_args()
    return args


def lvm_generate(args, model, output_dir, demo_video):
    """
    Run generation for one demo video: tokenize the prompt frames,
    autoregressively generate action-conditioned image tokens, and decode
    them back into an .mp4 next to a copy of the action .jsonl.
    """
    ### 1. set video input/output path
    input_mp4_path = os.path.join(args.data_root, demo_video)
    input_action_path = os.path.join(args.data_root, demo_video.replace('mp4','jsonl'))

    output_mp4_path = str(output_dir / demo_video)
    output_action_path = output_mp4_path.replace('.mp4', '.jsonl')
    # backup action 
    os.system(f"cp {input_action_path} {output_action_path}")
    if os.path.exists(output_mp4_path):
        print(f"output path {output_mp4_path} already exists, skipping")
        return {}
    
    device = model.transformer.device
    ### 2. load action into list 
    action_list = []
    mcdataset = MCDataset()
    with open(input_action_path, 'r') as f:
        for line in f:
            line = eval(line.strip(), {"__builtins__": None}, safe_globals)
            line['camera'] = np.array(line['camera'])
            act_index = mcdataset.get_action_index_from_actiondict(line, action_vocab_offset=8192)
            action_list.append(act_index)
    ### 3. load video frames 
    cap = cv2.VideoCapture(input_mp4_path)
    start_frame = 0
    end_frame = args.demo_num
    frames = []
    for frame_idx in range(start_frame, end_frame):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
            print(f"Error in reading frame {frame_idx}")
            continue
        cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)
        frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)
        frame = torch.from_numpy(frame)
        frames.append(frame)
    frames = torch.stack(frames, dim=0).to(device)
    frames = frames.permute(0, 3, 1, 2)
    frames = frames.float() / 255.0
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    frames = normalize(frames)
    
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        img_index = model.tokenizer.tokenize_images(frames)
    img_index = rearrange(img_index, '(b t) h w -> b t (h w)', b=1)
     
    all_generated_tokens = []

    action_all = action_list[end_frame: end_frame + args.frames]
    action_all = torch.tensor(action_all).unsqueeze(1).to(device)
    image_input = rearrange(img_index, 'b t c -> b (t c)')
    

    start_t = time.time() 
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
        if args.accelerate_algo == 'naive':
            outputs = model.transformer.naive_generate( input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all, top_k=args.top_k, top_p=args.top_p)
        elif args.accelerate_algo == 'image_diagd':
            outputs = model.transformer.img_diagd_generate(input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all,windowsize = args.window_size, top_k=args.top_k, top_p=args.top_p)
        else:
            raise ValueError(f"Unknown accelerate algorithm {args.accelerate_algo}")
    end_t = time.time()
    all_generated_tokens.extend(outputs.tolist()[0])
    new_length = len(all_generated_tokens)
    time_costed = end_t - start_t 
    token_per_sec = new_length / time_costed
    frame_per_sec = token_per_sec / TOKEN_PER_PIX
    print(f"{new_length} tokens generated; cost {time_costed:.3f} seconds; {token_per_sec:.3f} token/sec, {frame_per_sec:.3f} fps")
    token2video(all_generated_tokens, model.tokenizer, output_mp4_path, args.fps, device)
    # return for evaluation
    return_item = {
        "time_costed": time_costed,
        "token_num": new_length,
    }
    return return_item

if __name__ == '__main__':
    args = get_args()
    config = OmegaConf.load(args.config)
    output_path = Path(args.output_dir)
    precision_scope = autocast
    os.makedirs(output_path, exist_ok=True)

    model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True)
    print(f"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] Load Model From {args.model_ckpt}")
    # normalize and validate the accelerate algorithm option
    args.accelerate_algo = args.accelerate_algo.lower()
    if args.accelerate_algo not in ACCELERATE_ALGO:
        print(f"[bold red][Warning][/bold red] {args.accelerate_algo} is not in {ACCELERATE_ALGO}, use naive")
        args.accelerate_algo = 'naive'
    num_item = 0
    for root, _, files in os.walk(args.data_root):
        files = [f for f in files if f.endswith('.mp4')]  # keep only .mp4 files so the progress bar counts videos rather than .jsonl files
        files = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0]))
        for file in tqdm(files):
            return_item = lvm_generate(args, model, output_path,file)
            num_item += 1
            if num_item >= args.val_data_num:
                print(f"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] reached val data num limit {args.val_data_num}")
                break

================================================
FILE: lvm.py
================================================
"""
    Wrap the Huggingface Transformers Llama to PyTorch Lightning Module.
"""
import os
import sys
import inspect 
import torch
from typing import Optional
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers import LlamaConfig
from utils import get_obj_from_str, instantiate_from_config
from diagonal_decoding import decode_one_token, decode_some_token, decode_n_tokens, decode_n_tokens_for_gradio, prefill, img_diagd_decode_n_tokens, video_diagd_decode_n_tokens, img_diagd_decode_n_token_for_gradio
torch.backends.cuda.matmul.allow_tf32 = False

logger = logging.get_logger(__name__)
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)

if parentdir not in sys.path:
    sys.path.insert(0, parentdir)

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to query and key tensors.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        cos (torch.Tensor): Cosine values.
        sin (torch.Tensor): Sine values.
        position_ids (torch.Tensor): Position IDs.

    Returns:
        torch.Tensor: Query and key tensors with rotary position embeddings applied.
    """
    cos = cos[position_ids].unsqueeze(0).unsqueeze(2)
    sin = sin[position_ids].unsqueeze(0).unsqueeze(2)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
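
# Self-contained numeric check (illustrative only, not used by the model): for
# head_dim = 2, q * cos + rotate_half(q) * sin reduces to a plane rotation by
# the position angle, which is the geometric picture behind RoPE.
import math

def _rope_2d(q, theta):
    x1, x2 = q
    # rotate_half([x1, x2]) == [-x2, x1], so the update is a 2-D rotation
    return [x1 * math.cos(theta) - x2 * math.sin(theta),
            x2 * math.cos(theta) + x1 * math.sin(theta)]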


class LlamaLVM(torch.nn.Module):
    def __init__(
        self,
        transformer_config,
        model_class: str,
        tokenizer_config = None,
    ):
        super().__init__()
        self.config = instantiate_from_config(transformer_config)
        self.transformer = get_obj_from_str(model_class)(self.config)
        if tokenizer_config is not None:
            self.tokenizer = instantiate_from_config(tokenizer_config)

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
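
# Pure-Python sanity check (illustrative, not part of the module): RMSNorm
# divides by sqrt(mean(x^2) + eps); with the learned weight at its all-ones
# initialization, the output vector has (approximately) unit RMS.
def _rms_norm_ref(xs, eps=1e-6):
    import math
    rms = math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [v / rms for v in xs]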

class LlamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        device=None,
        config: Optional[LlamaConfig] = None,
    ):
        super().__init__()
        self.rope_kwargs = {}
        self.rope_type = "default"
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        self.max_position_embeddings = config.max_position_embeddings
        inv_freq, _ = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
    
    def _set_cos_sin_cache(self, device, dtype):
        """
        Precompute the cosine and sine caches for positional embeddings
        over the full max_position_embeddings range.

        Args:
            device (str): The device on which the cache tensors will be stored.
            dtype: The data type of the cache tensors.
        """
        t = torch.arange(
            self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos().to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin().to(dtype), persistent=False
        )
    
    def forward(self, x, seq_len=None):
        """
        Forward pass of the LlamaRotaryEmbedding module.

        Args:
            x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size].
            seq_len (int): The sequence length. If greater than the cached length, the cache will be updated.

        Returns:
            tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim].
        """
        if seq_len > self.max_position_embeddings:
            raise ValueError("seq_len must not exceed max_position_embeddings")

        return (
            self.cos_cached[:seq_len, :].to(dtype=x.dtype),
            self.sin_cached[:seq_len, :].to(dtype=x.dtype),
        )

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
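
# Illustrative note (hypothetical reference code, not part of the module):
# with the usual Llama setting of hidden_act = "silu", the MLP above is a
# SwiGLU block, down_proj(silu(gate_proj(x)) * up_proj(x)). The scalar gate
# nonlinearity it composes with:
def _silu_ref(x):
    import math
    return x / (1.0 + math.exp(-x))  # x * sigmoid(x)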
  
class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        assert (self.head_dim * self.num_heads) == self.hidden_size, "hidden_size must be divisible by num_heads"
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.max_batch_size = getattr(config, "max_batch_size", 1)
        self.init_kv_cache()

    def init_kv_cache(self, dtype=torch.float16):
        cache_shape = (self.max_batch_size, self.max_position_embeddings, self.num_key_value_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()
        self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            positions_embedding = None,
    ):
        
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        )
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        )

        cos, sin = positions_embedding


        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
        
        self.cache_k[:bsz, position_ids] = key_states
        self.cache_v[:bsz, position_ids] = value_states
        key_states, value_states = (
                self.cache_k[:bsz, :, :],
                self.cache_v[:bsz, :, :],
            )

        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=2)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=2)

        query_states, key_states, value_states = map(lambda x: x.transpose(1, 2), (query_states, key_states, value_states))
        
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        ).transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            positions_embedding = None,
    ):
        """
        Forward pass for the LlamaDecoderLayer.

        Args:
            hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`.
            attention_mask (torch.FloatTensor, optional): Attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (torch.LongTensor, optional): Positional IDs tensor.


        Returns:
            Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing:
                - hidden_states (torch.FloatTensor): Output tensor.
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            positions_embedding=positions_embedding,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class LlamaModel(PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.max_position_embedding = config.max_position_embeddings
        self.causal_mask = torch.tril(
            torch.ones(self.max_position_embedding, self.max_position_embedding, dtype=torch.bool)
        ).cuda()
        self.post_init()
        
    def _create_attention_mask(self, input_pos: Optional[torch.Tensor]):
        """
        Creates an attention mask for the transformer layers by slicing the
        precomputed causal mask at the given positions.

        Args:
            input_pos (torch.Tensor): Positions of the input tokens (used for inference only).

        Returns:
            torch.Tensor: The rows of the causal mask for the given positions.
        """
        mask = self.causal_mask[input_pos]
        return mask

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
    ):

        if input_ids is None:
            raise ValueError("input_ids must be provided")
        hidden_states = self.embed_tokens(input_ids)
        
        positions_embedding = self.rotary_emb(hidden_states, seq_len=self.max_position_embedding)

        attention_mask = self._create_attention_mask(input_pos=position_ids)
        for idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                positions_embedding=positions_embedding,
            )

            hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return hidden_states
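
# `_create_attention_mask` above slices rows of a precomputed lower-triangular
# boolean matrix by position id, so during incremental decoding each query token
# attends only to positions at or before its own. A minimal stdlib sketch of that
# indexing (illustrative names, no torch):

```python
# Minimal sketch of the row-sliced causal mask (pure Python, no torch).
# `causal_mask[i][j]` is True iff position j may be attended from position i.

def build_causal_mask(max_len: int) -> list[list[bool]]:
    """Lower-triangular boolean matrix, like torch.tril(torch.ones(...))."""
    return [[j <= i for j in range(max_len)] for i in range(max_len)]

def mask_rows(causal_mask: list[list[bool]], input_pos: list[int]) -> list[list[bool]]:
    """Equivalent of causal_mask[input_pos]: one mask row per query position."""
    return [causal_mask[p] for p in input_pos]

mask = build_causal_mask(4)
# During incremental decoding, a single new token at position 2 gets one row:
row = mask_rows(mask, [2])
print(row)  # [[True, True, True, False]]
```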
        
class LlamaForCausalLM(PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()      

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            position_ids: Optional[torch.LongTensor] = None,
    ):

        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
        )
        logits = self.lm_head(outputs)
        return logits

    def refresh_kvcache(self):
        for i in self.model.layers:
            i.self_attn.init_kv_cache()

    def naive_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, top_p=None, top_k=None):

        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
        if action_all is not None:
            input_ids = torch.cat([input_ids, action_all[0]], dim=-1)
        position_ids = torch.arange(0, input_ids.shape[1], device="cuda")
        next_token = self.prefill(
            self,
            input_ids=input_ids,
            position_ids=position_ids,
            temperature=temperature,
            top_k = top_k,
            top_p = top_p,
        )

        self.decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True)
        position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda")
        
        generated_tokens = decode_n_tokens(
            self,
            input_ids = next_token.view(1, -1),
            position_ids = position_ids,
            num_generate_tokens = max_new_tokens - 1,
            temperature = temperature,
            decode_one_token_function=self.decode_one_token,
            action=action_all,
            top_p = top_p,
            top_k = top_k,
        )
        return torch.cat(generated_tokens, dim=1)
    
    def prefill_for_gradio(self, input_ids, temperature=1.0):
        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
        last_pos = input_ids.shape[1]
        position_ids = torch.arange(0, last_pos, device="cuda")
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
            next_token = self.prefill(
                self,
                input_ids=input_ids,
                position_ids=position_ids,
                temperature=temperature,
            )
        return next_token, last_pos
    
    def decode_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0):
        self.decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True)
        # self.decode_one_token = decode_one_token
        position_ids = torch.arange(position_id, position_id + input_action.shape[1], device="cuda")
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
            generated_tokens, position_id = decode_n_tokens_for_gradio(
                self,
                input_ids = input_action,
                position_ids = position_ids,
                num_generate_tokens = max_new_tokens,
                temperature = temperature,
                decode_one_token_function=self.decode_one_token,
            )
        return generated_tokens, position_id
    
    def diagd_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0, windowsize=2):
        self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True)
        position_ids = torch.arange(position_id, position_id + input_action.shape[1], device="cuda")
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
            generated_tokens, position_id = img_diagd_decode_n_token_for_gradio(
                self,
                input_ids = input_action,
                position_ids = position_ids,
                num_generate_tokens = max_new_tokens,
                temperature = temperature,
                decode_some_token_function=self.decode_some_token,
                windowsize = windowsize,
            )
        return generated_tokens, position_id


    def img_diagd_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, windowsize=2, top_p=None, top_k=None):

        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
        input_ids = torch.cat([input_ids, action_all[0]], dim=-1)
        position_ids = torch.arange(0, input_ids.shape[1], device="cuda")
        next_token = self.prefill(
            self,
            input_ids=input_ids,
            position_ids=position_ids,
            temperature=temperature,
            top_k = top_k,
            top_p = top_p,
        )

        self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True)
        position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda")

        generated_tokens = img_diagd_decode_n_tokens(
            self,
            input_ids = next_token.view(1, -1),
            position_ids = position_ids,
            num_generate_tokens = max_new_tokens - 1,
            temperature = temperature,
            decode_some_token_function=self.decode_some_token,
            windowsize = windowsize,
            action=action_all,
            prompt=input_ids,
            top_k = top_k,
            top_p = top_p,
        )
        return torch.cat(generated_tokens, dim=1)
    
    def vid_diagd_generate(self, input_ids, max_new_tokens, windowsize=2, temperature=1.0, action_all=None, **kwargs):

        self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
        input_ids = torch.cat([input_ids, action_all[0]], dim=-1)
        position_ids = torch.arange(0, input_ids.shape[1], device="cuda")
        next_token = self.prefill(
            self,
            input_ids=input_ids,
            position_ids=position_ids,
            temperature=temperature,
        )

        self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True)
        # self.decode_some_token = decode_some_token
        position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda")

        generated_tokens = video_diagd_decode_n_tokens(
            self,
            input_ids = next_token.view(1, -1),
            position_ids = position_ids,
            num_generate_tokens = max_new_tokens - 1,
            temperature = temperature,
            decode_some_token_function=self.decode_some_token,
            windowsize = windowsize,
            action=action_all,
            prompt=input_ids,
            **kwargs
        )
        return torch.cat(generated_tokens, dim=1)




================================================
FILE: mcdataset.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import attr
import collections
import numpy as np
from typing import Union, Dict
import torch
from utils import print0


# https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L17
KEYBOARD_BUTTON_MAPPING = {
    "key.keyboard.escape" :"ESC",
    "key.keyboard.s" :"back",
    "key.keyboard.q" :"drop",
    "key.keyboard.w" :"forward",
    "key.keyboard.1" :"hotbar.1",
    "key.keyboard.2" :"hotbar.2",
    "key.keyboard.3" :"hotbar.3",
    "key.keyboard.4" :"hotbar.4",
    "key.keyboard.5" :"hotbar.5",
    "key.keyboard.6" :"hotbar.6",
    "key.keyboard.7" :"hotbar.7",
    "key.keyboard.8" :"hotbar.8",
    "key.keyboard.9" :"hotbar.9",
    "key.keyboard.e" :"inventory",
    "key.keyboard.space" :"jump",
    "key.keyboard.a" :"left",
    "key.keyboard.d" :"right",
    "key.keyboard.left.shift" :"sneak",
    "key.keyboard.left.control" :"sprint",
    "key.keyboard.f" :"swapHands",
}

# https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L41
# Template action
NOOP_ACTION = {
    "ESC": 0,
    "back": 0,
    "drop": 0,
    "forward": 0,
    "hotbar.1": 0,
    "hotbar.2": 0,
    "hotbar.3": 0,
    "hotbar.4": 0,
    "hotbar.5": 0,
    "hotbar.6": 0,
    "hotbar.7": 0,
    "hotbar.8": 0,
    "hotbar.9": 0,
    "inventory": 0,
    "jump": 0,
    "left": 0,
    "right": 0,
    "sneak": 0,
    "sprint": 0,
    "swapHands": 0,
    "camera": np.array([0, 0]),  # [y, x]
    "attack": 0,
    "use": 0,
    "pickItem": 0,
}
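
# The `NOOP_ACTION` template above holds a mutable `camera` array, which is why
# the conversion code later re-creates `env_action["camera"]` right after
# `NOOP_ACTION.copy()`: `dict.copy()` is shallow, so the array would otherwise be
# shared by every copied action. A stdlib sketch of the pitfall (a plain list
# stands in for the numpy array; names here are illustrative):

```python
# dict.copy() is shallow: nested mutables are shared, not duplicated.
template = {"jump": 0, "camera": [0, 0]}

a = template.copy()
a["camera"][0] = 5                     # mutates the shared list...
assert template["camera"] == [5, 0]    # ...leaking back into the template

b = template.copy()
b["camera"] = [0, 0]                   # safeguard: rebind a fresh list first
b["camera"][0] = 7
assert template["camera"] == [5, 0]    # template is no longer affected
```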

OASIS_ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraX",
    "cameraY",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


# Matches a number in the MineRL Java code regarding sensitivity
# This is for mapping from recorded sensitivity to the one used in the model
CAMERA_SCALER = 360.0 / 2400.0


# https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L8 with some modifications
class Buttons:
    # 14 in total without hotbar and camera
    ATTACK = "attack"
    BACK = "back"
    FORWARD = "forward"
    JUMP = "jump"
    LEFT = "left"
    RIGHT = "right"
    SNEAK = "sneak"
    SPRINT = "sprint"
    USE = "use"
    DROP = "drop"
    INVENTORY = "inventory"
    # added by Yang
    ESC = "ESC"
    SWAPHANDS = "swapHands"
    PICKITEM = "pickItem"

    ALL = [
        USE,
        ATTACK,

        FORWARD,
        BACK,
        LEFT,
        RIGHT,

        JUMP,
        SNEAK,
        SPRINT,
        
        DROP,
        SWAPHANDS,
        PICKITEM,

        INVENTORY,
        ESC,
    ] + [f"hotbar.{i}" for i in range(1, 10)]


class QuantizationScheme:
    LINEAR = "linear"
    MU_LAW = "mu_law"


# https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L49
@attr.s(auto_attribs=True)
class CameraQuantizer:
    """
    A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components.

    Parameters:
    - camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize.
    - camera_maxval: The maximum value of the camera action.
    - quantization_scheme: The quantization scheme to use. Currently, two quantization schemes are supported:
    - Linear quantization (default): Camera actions are split uniformly into discrete bins
    - Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm)
    followed by the same quantization scheme used by the linear scheme.
    - mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of
    mu will result in a sharper transition near zero. Below are some reference values listed
    for choosing mu given a constant maxval and a desired max_precision value.
    maxval = 10 | max_precision = 0.5  | μ ≈ 2.93826
    maxval = 10 | max_precision = 0.4  | μ ≈ 4.80939
    maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887
    maxval = 20 | max_precision = 0.5  | μ ≈ 2.7
    maxval = 20 | max_precision = 0.4  | μ ≈ 4.39768
    maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194
    maxval = 40 | max_precision = 0.5  | μ ≈ 2.60780
    maxval = 40 | max_precision = 0.4  | μ ≈ 4.21554
    maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152
    """

    camera_maxval: int
    camera_binsize: int
    quantization_scheme: str = attr.ib(
        default=QuantizationScheme.LINEAR,
        validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),
    )
    mu: float = attr.ib(default=5)

    def discretize(self, xy):
        xy = np.clip(xy, -self.camera_maxval, self.camera_maxval)

        if self.quantization_scheme == QuantizationScheme.MU_LAW:
            xy = xy / self.camera_maxval
            v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu))
            v_encode *= self.camera_maxval
            xy = v_encode

        # Quantize using linear scheme
        return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64)

    def undiscretize(self, xy):
        xy = xy * self.camera_binsize - self.camera_maxval

        if self.quantization_scheme == QuantizationScheme.MU_LAW:
            xy = xy / self.camera_maxval
            v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0)
            v_decode *= self.camera_maxval
            xy = v_decode
        return xy
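
# A scalar re-derivation of the mu-law scheme above, using `math` instead of
# numpy. This is only a sketch, not the class itself; the constants match the
# `MCDataset` defaults below (maxval 90, binsize 9, mu ~= 11.4887), which give
# 21 bins per camera axis.

```python
import math

MAXVAL, BINSIZE, MU = 90, 9, 11.4887

def discretize(v: float) -> int:
    """Scalar version of CameraQuantizer.discretize with mu-law encoding."""
    v = max(-MAXVAL, min(MAXVAL, v))                  # clip to [-maxval, maxval]
    v = v / MAXVAL                                    # normalize to [-1, 1]
    v = math.copysign(math.log1p(MU * abs(v)) / math.log1p(MU), v) * MAXVAL
    return round((v + MAXVAL) / BINSIZE)              # linear binning

def undiscretize(b: int) -> float:
    """Scalar inverse: bin index back to a camera delta in degrees."""
    v = (b * BINSIZE - MAXVAL) / MAXVAL
    return math.copysign(((1.0 + MU) ** abs(v) - 1.0) / MU, v) * MAXVAL

# A zero camera delta maps to the middle of the 0..20 bin range and round-trips:
assert discretize(0.0) == 10
assert abs(undiscretize(discretize(0.0))) < 1e-9
```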


class MCDataset(torch.utils.data.Dataset):
    """
    Dataset for Minecraft.
    """
    def __init__(self,
                 action_length: int = 11,  # including bos and eos
                 camera_binsize: int = 9,  # 2 in vpt
                 camera_maxval: int = 90,  # 10 in vpt
                 camera_mu: float = 11.4887,  # 10 in vpt
                 quantization_scheme: str = "mu_law",
    ):
        self.action_length = action_length
        self.camera_quantizer = CameraQuantizer(
            camera_binsize=camera_binsize,
            camera_maxval=camera_maxval,
            mu=camera_mu,
            quantization_scheme=quantization_scheme,
        )

    def json_action_to_env_action(self, json_action):
        """
        https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L80
        Converts a json action into a MineRL action.
        Returns (minerl_action, is_null_action)
        """
        # This might be slow...
        env_action = NOOP_ACTION.copy()
        # As a safeguard, make camera action again so we do not override anything
        env_action["camera"] = np.array([0, 0])

        is_null_action = True
        keyboard_keys = json_action["keyboard"]["keys"]
        for key in keyboard_keys:
            # You can have keys that we do not use, so just skip them
            # NOTE in original training code, ESC was removed and replaced with
            #      "inventory" action if GUI was open.
            #      Not doing it here, as BASALT uses ESC to quit the game.
            if key in KEYBOARD_BUTTON_MAPPING:
                env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1
                is_null_action = False

        mouse = json_action["mouse"]
        camera_action = env_action["camera"]
        camera_action[0] = mouse["dy"] * CAMERA_SCALER
        camera_action[1] = mouse["dx"] * CAMERA_SCALER

        if mouse["dx"] != 0 or mouse["dy"] != 0:
            is_null_action = False
        else:
            if abs(camera_action[0]) > 180:
                camera_action[0] = 0
            if abs(camera_action[1]) > 180:
                camera_action[1] = 0

        mouse_buttons = mouse["buttons"]
        if 0 in mouse_buttons:
            env_action["attack"] = 1
            is_null_action = False
        if 1 in mouse_buttons:
            env_action["use"] = 1
            is_null_action = False
        if 2 in mouse_buttons:
            env_action["pickItem"] = 1
            is_null_action = False

        # added by Yang
        # if two conflicting actions are pressed simultaneously, cancel both
        if env_action["forward"] == 1 and env_action["back"] == 1:
            env_action["forward"] = 0
            env_action["back"] = 0
        if env_action["left"] == 1 and env_action["right"] == 1:
            env_action["left"] = 0
            env_action["right"] = 0 
        if env_action["jump"] == 1 and env_action["sneak"] == 1:
            env_action["jump"] = 0
            env_action["sneak"] = 0
        if env_action["sprint"] == 1 and env_action["sneak"] == 1:
            env_action["sprint"] = 0
            env_action["sneak"] = 0
        if env_action["attack"] == 1 and env_action["use"] == 1:
            env_action["attack"] = 0
            env_action["use"] = 0

        # remove inventory and ESC action
        if env_action["inventory"] == 1:
            is_null_action = True
        if env_action["ESC"] == 1:
            is_null_action = True

        return env_action, is_null_action

    def make_action_vocab(self,
                          num_cam_bins: int = 21,
                          action_vocab_offset: int = 0,
                          verbose: bool = False):
        action_vocab = collections.OrderedDict()
        # 14 actions and hotbar.1-9
        for i, action in enumerate(Buttons.ALL):
            action_vocab[action] = i
        # camera 0 
        for i in range(num_cam_bins):
            action_vocab[f"cam_0_{i}"] = len(Buttons.ALL) + i
        # camera 1
        for i in range(num_cam_bins):
            action_vocab[f"cam_1_{i}"] = len(Buttons.ALL) + num_cam_bins + i
        # bos, null, eos
        action_vocab["<act_bos>"] = len(Buttons.ALL) + 2 * num_cam_bins
        action_vocab["<null_act>"] = len(Buttons.ALL) + 2 * num_cam_bins + 1
        action_vocab["<act_eos>"] = len(Buttons.ALL) + 2 * num_cam_bins + 2

        if action_vocab_offset > 0:
            action_vocab = {k: v + action_vocab_offset for k, v in action_vocab.items()}

        if verbose:
            print0(f"[bold yellow]\[MCDataset][/bold yellow] Action Vocab: {action_vocab}")

        self.action_vocab = action_vocab
        # return action_vocab

    def _handle_conflict_action_index(self,
                                action_dict: Dict[str, Union[int, np.ndarray]],
                                key1: str,
                                key2: str,
                                null_key: str,
                                verbose: bool = False):
        if action_dict[key1] == 1 and action_dict[key2] == 1:
            if verbose:
                print0(f"[bold yellow]\[MCDataset][/bold yellow] {key1} and {key2} are both pressed")
            return self.action_vocab[null_key]
        elif action_dict[key1] == 1:
            return self.action_vocab[key1]
        elif action_dict[key2] == 1:
            return self.action_vocab[key2]
        else:
            return self.action_vocab[null_key]

    def get_action_index_from_actiondict(self,
                                         action_dict: Dict[str, Union[int, np.ndarray]],
                                         action_vocab_offset: int = 0,
                                         verbose: bool = False):

        if not hasattr(self, "action_vocab"):
            self.make_action_vocab(action_vocab_offset=action_vocab_offset, verbose=verbose)

        # action_list = [boa, camy, camx, hotbar, fore_back, left_right, sprint_sneak, use_attack, jump, drop_pick, eoa]
        # 11 actions
        action_list = [self.action_vocab["<null_act>"]] * self.action_length
        # 0 & 10
        action_list[0] = self.action_vocab["<act_bos>"]
        action_list[-1] = self.action_vocab["<act_eos>"]

        camera_action = action_dict["camera"]
        assert len(camera_action) == 2, f"[MCDataset] camera_action length is not 2: {camera_action}"
        # camera_action should be numpy array
        if not isinstance(camera_action, np.ndarray):
            camera_action = np.array(camera_action)
        camera_action = self.camera_quantizer.discretize(camera_action)
        # 1 & 2
        action_list[1] = self.action_vocab[f"cam_0_{camera_action[0]}"]
        action_list[2] = self.action_vocab[f"cam_1_{camera_action[1]}"]

        # 3
        for i in range(1, 10):
            if f"hotbar.{i}" in action_dict and action_dict[f"hotbar.{i}"] == 1:
                action_list[3] = self.action_vocab[f"hotbar.{i}"]
                break

        # 4 forward/back
        action_list[4] = self._handle_conflict_action_index(action_dict, "forward", "back", "<null_act>", verbose=verbose)
        # 5 left/right
        action_list[5] = self._handle_conflict_action_index(action_dict, "left", "right", "<null_act>", verbose=verbose)
        # 6 sprint/sneak
        action_list[6] = self._handle_conflict_action_index(action_dict, "sprint", "sneak", "<null_act>", verbose=verbose)
        # 7 use/attack
        action_list[7] = self._handle_conflict_action_index(action_dict, "use", "attack", "<null_act>", verbose=verbose)
        # 8 jump
        action_list[8] = self.action_vocab["jump"] if action_dict["jump"] == 1 else self.action_vocab["<null_act>"]
        # 9 drop/pick
        action_list[9] = self._handle_conflict_action_index(action_dict, "drop", "pickItem", "<null_act>", verbose=verbose)

        if verbose:
            print0(f"[bold yellow]\[MCDataset][/bold yellow] Action List: {action_list}")

        return action_list

    def read_jsonl(self, jsonl_path: str):
        assert os.path.isfile(jsonl_path), f"[MCDataset] {jsonl_path} does not exist"
        # read jsonl
        # https://github.com/openai/Video-Pre-Training/blob/main/data_loader.py#L76
        try:
            with open(jsonl_path) as json_file:
                json_lines = json_file.readlines()
                json_data = "[" + ",".join(json_lines) + "]"
                json_data = json.loads(json_data)
        except Exception as e:
            print0(f"[bold yellow]\[MCDataset][/bold yellow] {jsonl_path} cannot be read: {e}")
            return None
        return json_data
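
# To make the fixed 11-token action layout concrete, here is a self-contained
# sketch that rebuilds the vocabulary in the same order as `make_action_vocab`
# (14 buttons + 9 hotbar keys, then two camera axes of 21 bins each, then the
# three special tokens) and encodes a "walk forward while jumping, camera
# centered" action. This is an illustration, not the dataset class itself.

```python
# Slot order (from get_action_index_from_actiondict):
#   [<act_bos>, cam_y, cam_x, hotbar, fwd/back, left/right,
#    sprint/sneak, use/attack, jump, drop/pick, <act_eos>]
NUM_CAM_BINS = 21
BUTTONS = ["use", "attack", "forward", "back", "left", "right", "jump",
           "sneak", "sprint", "drop", "swapHands", "pickItem",
           "inventory", "ESC"] + [f"hotbar.{i}" for i in range(1, 10)]

vocab = {b: i for i, b in enumerate(BUTTONS)}
for i in range(NUM_CAM_BINS):
    vocab[f"cam_0_{i}"] = len(BUTTONS) + i
    vocab[f"cam_1_{i}"] = len(BUTTONS) + NUM_CAM_BINS + i
vocab["<act_bos>"] = len(BUTTONS) + 2 * NUM_CAM_BINS
vocab["<null_act>"] = vocab["<act_bos>"] + 1
vocab["<act_eos>"] = vocab["<act_bos>"] + 2

# Encode "walk forward while jumping", with camera centered (middle bin 10):
tokens = [vocab["<null_act>"]] * 11
tokens[0], tokens[-1] = vocab["<act_bos>"], vocab["<act_eos>"]
tokens[1], tokens[2] = vocab["cam_0_10"], vocab["cam_1_10"]
tokens[4] = vocab["forward"]
tokens[8] = vocab["jump"]
print(tokens)
```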

================================================
FILE: metrics/IDM/inverse_dynamics_model.py
================================================
# Borrowed from VPT (https://github.com/openai/Video-Pre-Training)

import numpy as np
import torch as th
import cv2
from gym3.types import DictType
from gym import spaces
from tqdm import tqdm
import os
from argparse import ArgumentParser
import pickle
import json
from lib.action_mapping import IDMActionMapping
from lib.actions import ActionTransformer
from lib.policy import InverseActionPolicy
from lib.torch_util import default_device_type, set_default_torch_device

from sklearn.metrics import precision_score, recall_score, f1_score
# Hardcoded settings
AGENT_RESOLUTION = (128, 128)
safe_globals = {"array": np.array}
def resize_image(img, target_resolution):
    # For your sanity, do not resize with any function other than INTER_LINEAR
    img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
    return img


ACTION_TRANSFORMER_KWARGS = dict(
    camera_binsize=2,
    camera_maxval=10,
    camera_mu=10,
    camera_quantization_scheme="mu_law",
)

class IDMAgent:
    """
    Sugarcoating on the inverse dynamics model (IDM) used to predict actions Minecraft players take in videos.

    Functionally same as MineRLAgent.
    """
    def __init__(self, idm_net_kwargs, pi_head_kwargs, device=None):
        if device is None:
            device = default_device_type()
        self.device = th.device(device)
        # Set the default torch device for underlying code as well
        set_default_torch_device(self.device)
        self.action_mapper = IDMActionMapping(n_camera_bins=11)
        action_space = self.action_mapper.get_action_space_update()
        action_space = DictType(**action_space)

        self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)

        idm_policy_kwargs = dict(idm_net_kwargs=idm_net_kwargs, pi_head_kwargs=pi_head_kwargs, action_space=action_space)

        self.policy = InverseActionPolicy(**idm_policy_kwargs).to(device)
        self.hidden_state = self.policy.initial_state(1)
        self._dummy_first = th.from_numpy(np.array((False,))).to(device)

    def load_weights(self, path):
        """Load model weights from a path, and reset hidden state"""
        self.policy.load_state_dict(th.load(path, map_location=self.device), strict=False)
        self.reset()

    def reset(self):
        """Reset agent to initial state (i.e., reset hidden state)"""
        self.hidden_state = self.policy.initial_state(1)

    def _video_obs_to_agent(self, video_frames):
        imgs = [resize_image(frame, AGENT_RESOLUTION) for frame in video_frames]
        # Add time and batch dim
        imgs = np.stack(imgs)[None]
        agent_input = {"img": th.from_numpy(imgs).to(self.device)}
        return agent_input

    def _agent_action_to_env(self, agent_action):
        """Turn output from policy into action for MineRL"""
        # This is quite an important step (for some reason).
        # For the sake of your sanity, remember to do this step (manual conversion to numpy)
        # before proceeding. Otherwise, your agent might be a little derp.
        action = {
            "buttons": agent_action["buttons"].cpu().numpy(),
            "camera": agent_action["camera"].cpu().numpy()
        }
        minerl_action = self.action_mapper.to_factored(action)
        minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
        return minerl_action_transformed

    def predict_actions(self, video_frames):
        """
        Predict actions for a sequence of frames.

        `video_frames` should be of shape (N, H, W, C).
        Returns MineRL action dict, where each action head
        has shape (N, ...).

        Agent's hidden state is tracked internally. To reset it,
        call `reset()`.
        """
        agent_input = self._video_obs_to_agent(video_frames)
        # The "first" argument could be used to reset tell episode
        # boundaries, but we are only using this for predicting (for now),
        # so we do not hassle with it yet.
        dummy_first = th.zeros((video_frames.shape[0], 1)).to(self.device)
        predicted_actions, self.hidden_state, _ = self.policy.predict(
            agent_input, first=dummy_first, state_in=self.hidden_state,
            deterministic=True
        )
        predicted_minerl_action = self._agent_action_to_env(predicted_actions)
        return predicted_minerl_action

# NOTE: this is _not_ the original code of IDM!
# As such, while it is close and seems to function well,
# its performance might be a bit off from what is reported
# in the paper.

ENV_KWARGS = dict(
    fov_range=[70, 70],
    frameskip=1,
    gamma_range=[2, 2],
    guiscale_range=[1, 1],
    resolution=[640, 360],
    cursor_size_range=[16.0, 16.0],
)


KEYBOARD_BUTTON_MAPPING = {
    "key.keyboard.escape" :"ESC",
    "key.keyboard.s" :"back",
    "key.keyboard.q" :"drop",
    "key.keyboard.w" :"forward",
    "key.keyboard.1" :"hotbar.1",
    "key.keyboard.2" :"hotbar.2",
    "key.keyboard.3" :"hotbar.3",
    "key.keyboard.4" :"hotbar.4",
    "key.keyboard.5" :"hotbar.5",
    "key.keyboard.6" :"hotbar.6",
    "key.keyboard.7" :"hotbar.7",
    "key.keyboard.8" :"hotbar.8",
    "key.keyboard.9" :"hotbar.9",
    "key.keyboard.e" :"inventory",
    "key.keyboard.space" :"jump",
    "key.keyboard.a" :"left",
    "key.keyboard.d" :"right",
    "key.keyboard.left.shift" :"sneak",
    "key.keyboard.left.control" :"sprint",
    "key.keyboard.f" :"swapHands",
}

# Template action
NOOP_ACTION = {
    "ESC": 0,
    "back": 0,
    "drop": 0,
    "forward": 0,
    "hotbar.1": 0,
    "hotbar.2": 0,
    "hotbar.3": 0,
    "hotbar.4": 0,
    "hotbar.5": 0,
    "hotbar.6": 0,
    "hotbar.7": 0,
    "hotbar.8": 0,
    "hotbar.9": 0,
    "inventory": 0,
    "jump": 0,
    "left": 0,
    "right": 0,
    "sneak": 0,
    "sprint": 0,
    "swapHands": 0,
    "camera": np.array([0, 0]),
    "attack": 0,
    "use": 0,
    "pickItem": 0,
}


# Matches a number in the MineRL Java code regarding sensitivity
# This is for mapping from recorded sensitivity to the one used in the model
CAMERA_SCALER = 360.0 / 2400.0


def json_action_to_env_action(json_action):
    """
    Converts a json action into a MineRL action.
    Returns (minerl_action, is_null_action)
    """
    if "ESC" in json_action:
        return json_action, False
    # This might be slow...
    env_action = NOOP_ACTION.copy()
    # As a safeguard, make camera action again so we do not override anything
    env_action["camera"] = np.array([0, 0])

    is_null_action = True
    keyboard_keys = json_action["keyboard"]["keys"]
    for key in keyboard_keys:
        # You can have keys that we do not use, so just skip them
        # NOTE in original training code, ESC was removed and replaced with
        #      "inventory" action if GUI was open.
        #      Not doing it here, as BASALT uses ESC to quit the game.
        if key in KEYBOARD_BUTTON_MAPPING:
            env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1
            is_null_action = False

    mouse = json_action["mouse"]
    camera_action = env_action["camera"]
    camera_action[0] = mouse["dy"] * CAMERA_SCALER
    camera_action[1] = mouse["dx"] * CAMERA_SCALER

    if mouse["dx"] != 0 or mouse["dy"] != 0:
        is_null_action = False
    else:
        if abs(camera_action[0]) > 180:
            camera_action[0] = 0
        if abs(camera_action[1]) > 180:
            camera_action[1] = 0

    mouse_buttons = mouse["buttons"]
    if 0 in mouse_buttons:
        env_action["attack"] = 1
        is_null_action = False
    if 1 in mouse_buttons:
        env_action["use"] = 1
        is_null_action = False
    if 2 in mouse_buttons:
        env_action["pickItem"] = 1
        is_null_action = False

    return env_action, is_null_action


def load_action_jsonl(json_path):
    with open(json_path) as json_file:
        json_lines = json_file.readlines()
        json_data = "[" + ",".join(json_lines) + "]"
        json_data = json.loads(json_data)
    return json_data 
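
# `load_action_jsonl` above relies on a single-`json.loads` trick: joining the
# raw lines with commas and wrapping them in brackets turns the whole .jsonl
# file into one JSON array. A minimal sketch, assuming one well-formed JSON
# object per line (which is what the recorder emits); the sample data here is
# made up for illustration:

```python
import json

jsonl_text = '{"tick": 0, "mouse": {"dx": 0}}\n{"tick": 1, "mouse": {"dx": 3}}\n'
lines = jsonl_text.splitlines()
# One parse for the whole file instead of json.loads per line:
records = json.loads("[" + ",".join(lines) + "]")
print(records[1]["mouse"]["dx"])  # 3
```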


# loss on frame - avg on video - avg on dataset 
def evaluate_IDM_quality(model, weights, jsonl_folder, video_folder, infer_demo_num, n_frames, output_file):
    """
    Evaluate the quality of an IDM model on a dataset of videos.

    Args:
        model (str): Path to the '.model' file to be loaded.
        weights (str): Path to the '.weights' file to be loaded.
        jsonl_folder (str): Path to the folder containing ground-truth action jsonl files.
        video_folder (str): Path to the folder containing videos.
        infer_demo_num (int): Forwarded to `idm_prediction`.
        n_frames (int): Number of frames to process at a time.
        output_file (str): Path for writing evaluation results.
    """
    ## set up IDM model 
    agent_parameters = pickle.load(open(model, "rb"))
    net_kwargs = agent_parameters["model"]["args"]["net"]["args"]
    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
    # pi_head_kwargs["temperature"] = 1.0
    agent = IDMAgent(idm_net_kwargs=net_kwargs, pi_head_kwargs=pi_head_kwargs)
    agent.load_weights(weights)
    # Load video files
    video_files = os.listdir(video_folder)
    video_files = [f for f in video_files if f.endswith(".mp4")]
    video_files = sorted(video_files)
    video_files = [os.path.join(video_folder, f) for f in video_files]
    eval_num = min(500, len(video_files))  # cap evaluation at 500 videos
    video_files = video_files[:eval_num]
    dataset_labels = {}
    camera_loss_list = []
    for video_file in tqdm(video_files):
        json_file = os.path.join(jsonl_folder,os.path.basename(video_file).replace(".mp4",".jsonl"))
        
        # load predicted actions and recorded actions
        predicted_actions,recorded_actions = idm_prediction(agent, video_file,json_file, infer_demo_num, n_frames)
        # construct labels
        subtasks_labels = define_exclusive_classification_task(predicted_actions,recorded_actions,calculate_hot_bar = False)
        for key in subtasks_labels:
            if key not in dataset_labels:
                dataset_labels[key] = {"pred_labels":[] , "rec_labels":[], "class_num":0}
            dataset_labels[key]["pred_labels"].append(subtasks_labels[key]["pred_labels"])# array 
            dataset_labels[key]["rec_labels"].append(subtasks_labels[key]["rec_labels"]) # array 
            dataset_labels[key]["class_num"] = subtasks_labels[key]["class_num"]
        camera_loss_list.append(camera_loss(predicted_actions,recorded_actions)["camera_bin_loss"])
        
    dataset_results ={}
    for key in dataset_labels:
        pred_labels = np.stack(dataset_labels[key]["pred_labels"]).flatten() # [num_videos , num_frames] -> [video_num x frame_num]
        rec_labels = np.stack(dataset_labels[key]["rec_labels"]).flatten()   # [num_videos , num_frames] -> [video_num x frame_num]
        dataset_results[key]=classification_metric(pred_labels, rec_labels, dataset_labels[key]["class_num"])
    
    metric_mean_on_task = {}
    metrics = ['precision_micro', 'recall_micro', 'f1_micro', 'precision_macro', 'recall_macro', 'f1_macro']
    tasks = dataset_results.keys()
    for key in metrics:
        metric_mean_on_task[key] = np.mean([dataset_results[task][key] for task in tasks])
    dataset_results["metric_mean_on_task"] = metric_mean_on_task
    dataset_results["metric_mean_on_task"]["camera_loss"] = np.mean(camera_loss_list)
    ## convert all (possibly tuple) keys to strings so the results are JSON-serializable
    dataset_results = {str(k): v for k, v in dataset_results.items()}
        
    print(dataset_results)
    print("===========================================")
    print(f"{output_file} IDM Metric: {metric_mean_on_task}")
    with open(output_file, 'w') as f:
        f.write(json.dumps(dataset_results,indent=4) + "\n")

def construct_classification_labels(idm_actions: dict[str, np.ndarray], action_name_keys: tuple[str, ...], num_class: int) -> np.ndarray:
    """
    Convert raw predicted/recorded button actions into exclusive classification labels.
    Returns None if conflicting predictions are found.
    """
    # construct a one-hot vector string to int label 
    vec2cls = {"0"*(num_class-1):0}
    for i in range(num_class-1):
        key = "0"*i + "1" + "0"*(num_class-2-i)
        vec2cls[key] = i+1
    vec2cls['1' * (num_class - 1)] = 0  # pressing every button at once cancels out, same as pressing none
    # e.g. for num_class = 3: {"00": 0, "10": 1, "01": 2, "11": 0}
    num_labels = idm_actions[action_name_keys[0]].size # assert same length: video_num x frame_per_video
    # arrays may be [video_num, frames_per_video]; flatten them below
    
    # construct one-hot vector
    idm_action_string = [[str(int(i)) for i in idm_actions[action_name].flatten()] for action_name in action_name_keys]
    try:
        labels = [vec2cls["".join([idm_action_string[j][i] for j in range(num_class-1)])] for i in range(num_labels)]
    except KeyError:
        conflicts_num = sum([i == '1' and j == '1' for i, j in zip(idm_action_string[0], idm_action_string[1])])
        print(f"detected conflicting predictions: {conflicts_num}")
        return None
        
    labels = np.array(labels)
    return labels
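The `vec2cls` construction above is easiest to see on a concrete case. The hypothetical helper below rebuilds the same mapping for a mutually exclusive pair such as `("back", "forward")` (`num_class = 3`): "no button pressed" and "all buttons pressed" both map to class 0, and each single button gets its own class. This is an illustrative sketch, not part of the repository:

```python
def build_vec2cls(num_class: int) -> dict:
    # "none pressed" maps to class 0
    vec2cls = {"0" * (num_class - 1): 0}
    # exactly one button pressed maps to its own class
    for i in range(num_class - 1):
        vec2cls["0" * i + "1" + "0" * (num_class - 2 - i)] = i + 1
    # all pressed simultaneously cancels out, treated as "none"
    vec2cls["1" * (num_class - 1)] = 0
    return vec2cls

mapping = build_vec2cls(3)  # e.g. for the ("back", "forward") pair
```

For the hotbar group (9 buttons, `num_class = 10`) the same construction yields 11 keys: the all-zeros string, nine one-hot strings, and the all-ones string.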

def define_exclusive_classification_task(predicted_actions:dict,recorded_actions:dict,calculate_hot_bar = False) -> dict:
    subtasks = {"multi_class":[("back","forward"),# 01,00,10,  
                                ("left","right"),
                                ("sneak","sprint"),
                                ],
                    "binary_class":["use","attack","jump","drop"]
    }
    if calculate_hot_bar:
        subtasks["multi_class"]=[("hotbar.1","hotbar.2","hotbar.3","hotbar.4","hotbar.5","hotbar.6","hotbar.7","hotbar.8","hotbar.9")]
    subtasks_labels = {}
    for class_pair in subtasks["multi_class"]:
        class_num = len(class_pair) + 1  # a 2-button pair yields 3 classes: neither, first, second
        # convert raw button actions to exclusive classification labels
        pred_labels = construct_classification_labels(predicted_actions, class_pair, class_num)
        rec_labels = construct_classification_labels(recorded_actions, class_pair, class_num)
        if pred_labels is None or rec_labels is None:
            print(f"skipping {class_pair}: conflicting predictions detected")
            continue
        subtasks_labels[class_pair] = {"class_num":class_num,
                                       "pred_labels":pred_labels,
                                       "rec_labels":rec_labels
                                       }
    for binary_task in subtasks["binary_class"]:
        pred_labels =  predicted_actions[binary_task] 
        rec_labels =   recorded_actions[binary_task] 
        subtasks_labels[binary_task] = {"class_num":2,
                                       "pred_labels":pred_labels,
                                       "rec_labels":rec_labels
                                       }
    return subtasks_labels

def classification_metric(pred_labels, rec_labels, class_num):
    ## compute macro- and micro-averaged scores for both the multi-class and binary tasks.
    ## The difference from plain binary precision: binary precision only scores the positive label (1),
    ## while micro/macro averaging also scores label 0 and then combines the two.
    ## To stay consistent with the multi-class tasks we use average="macro" and average="micro".
    precision_micro = precision_score(rec_labels, pred_labels, average="micro", zero_division=0)
    recall_micro = recall_score(rec_labels, pred_labels, average="micro", zero_division=0)
    f1_micro = f1_score(rec_labels, pred_labels, average="micro", zero_division=0)
    
    precision_macro = precision_score(rec_labels, pred_labels, average="macro", zero_division=0)
    recall_macro = recall_score(rec_labels, pred_labels, average="macro", zero_division=0)
    f1_macro = f1_score(rec_labels, pred_labels, average="macro", zero_division=0)
    return {
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        "class_num": class_num
    }
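The micro/macro distinction in `classification_metric` can be checked by hand. In this illustrative sketch (plain Python rather than sklearn), macro precision averages per-class precisions, while micro precision pools all decisions, which for single-label classification equals accuracy:

```python
def precision_per_class(y_true, y_pred, classes):
    # precision for each class: TP / (TP + FP), 0.0 if the class is never predicted
    out = {}
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if p == c and t == c)
        predicted = sum(1 for p in y_pred if p == c)
        out[c] = tp / predicted if predicted else 0.0
    return out

y_true = [1, 0, 0, 0]
y_pred = [1, 1, 1, 0]
per_class = precision_per_class(y_true, y_pred, [0, 1])
macro = sum(per_class.values()) / 2  # average of per-class scores
# micro pools every decision; for single-label tasks this is plain accuracy
micro = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
```

Here class 1 has precision 1/3 (one true positive out of three positive predictions) and class 0 has precision 1.0, so macro is 2/3 while micro is 0.5 — the two averages diverge whenever class performance is imbalanced.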

def aggregate_actions(actions:list) -> dict:
    return_dict = {}
    for action in actions:
        for key in action:
            if key not in return_dict:
                return_dict[key] = []
            return_dict[key].append(action[key])
    for key in return_dict:
        return_dict[key] = np.array(return_dict[key]).reshape(-1)
    return return_dict
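`aggregate_actions` merges per-frame action dicts into per-key sequences. A minimal pure-Python sketch of the same idea (the repository version additionally converts each list into a flattened NumPy array):

```python
def aggregate(actions):
    # merge a list of per-frame action dicts into one dict of per-key lists
    merged = {}
    for action in actions:
        for key, value in action.items():
            merged.setdefault(key, []).append(value)
    return merged

frames = [{"attack": 0, "jump": 1}, {"attack": 1, "jump": 0}]
merged = aggregate(frames)
```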

def idm_prediction(agent, video_path,json_path, infer_demo_num, n_frames):
    th.cuda.empty_cache()
    full_json_data = load_action_jsonl(json_path)
    json_data = full_json_data[infer_demo_num:infer_demo_num+n_frames]
    recorded_actions = [json_action_to_env_action(i)[0] for i in json_data]
    recorded_actions = aggregate_actions(recorded_actions)
    frames = []
    cap = cv2.VideoCapture(video_path)
    for _ in range(n_frames):
        ret, frame = cap.read()
        if not ret:
            print(f"[Error] loading frames in {video_path}: stopped at frame {_}")
            cap.release()
            return None, None
        # BGR -> RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    frames = np.stack(frames)
    predicted_actions = agent.predict_actions(frames)
    for key in predicted_actions:
        if key == "camera":
            continue
        predicted_actions[key] = np.array(predicted_actions[key]).reshape(-1)
    return predicted_actions,recorded_actions

def camera_loss(predicted_actions,recorded_actions):
    from lib.actions import CameraQuantizer
    cam_quantizer = CameraQuantizer(
        camera_binsize=2,
        camera_maxval=10,
        mu=10,
        quantization_scheme="mu_law")
    cam_pred_token = cam_quantizer.discretize(predicted_actions['camera'].reshape(-1))
    cam_gt_token = cam_quantizer.discretize(np.array(recorded_actions['camera']))
    camera_bin_loss = np.abs(cam_pred_token - cam_gt_token).mean()
    return {
        "camera_bin_loss":camera_bin_loss
    }
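`camera_loss` compares camera angles after mu-law discretization. The sketch below is an assumed reimplementation of that scheme (clip, mu-law compand, then uniform binning) using the same parameters passed to `CameraQuantizer` above (`camera_maxval=10`, `mu=10`, `camera_binsize=2`); consult `lib.actions.CameraQuantizer` for the authoritative version:

```python
import math

def mu_law_discretize(val, maxval=10.0, mu=10.0, binsize=2):
    # clip to the representable camera range
    v = max(-maxval, min(maxval, val))
    # mu-law companding: compresses large angles, keeps resolution near 0
    v = v / maxval
    v = math.copysign(math.log1p(mu * abs(v)) / math.log1p(mu), v)
    v = v * maxval
    # uniform binning over [-maxval, maxval]
    return round((v + maxval) / binsize)
```

With these parameters the scheme yields 11 bins (0..10), with bin 5 as the null (no-movement) bin, matching the odd `n_camera_bins` used elsewhere in this module.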

if __name__ == "__main__":
    parser = ArgumentParser("Evaluate IDM quality for MC-LVM ")
    parser.add_argument("--weights", type=str, required=True, help="[IDM model config] Path to the '.weights' file to be loaded.")
    parser.add_argument("--model", type=str, required=True, help="[IDM model config] Path to the '.model' file to be loaded.")
    parser.add_argument("--jsonl-path", type=str, required=True, help="[Eval Config] Path to .jsonl contains actions.")
    parser.add_argument("--video-path", type=str, required=True, help="[Eval Config] Path to a .mp4 file.")
    parser.add_argument("--infer-demo-num", type=int, default=0, help="[Inference Config] Number of frames to skip before starting evaluation.")
    parser.add_argument("--n-frames", type=int, default=32, help="[Inference Config] Number of frames to generation.")
    parser.add_argument("--output-file", type=str, default="[Eval Config] output/action_loss.jsonl", help="[Eval Config] Path to save the action loss.")
    args = parser.parse_args()
    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
    evaluate_IDM_quality(args.model, args.weights,args.jsonl_path ,args.video_path, args.infer_demo_num,args.n_frames,args.output_file)


================================================
FILE: metrics/IDM/lib/__init__.py
================================================


================================================
FILE: metrics/IDM/lib/action_head.py
================================================
import logging
from typing import Any, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from gym3.types import DictType, Discrete, Real, TensorType, ValType

LOG0 = -100


def fan_in_linear(module: nn.Module, scale=1.0, bias=True):
    """Fan-in init"""
    module.weight.data *= scale / module.weight.norm(dim=1, p=2, keepdim=True)

    if bias:
        module.bias.data *= 0


class ActionHead(nn.Module):
    """Abstract base class for action heads compatible with forc"""

    def forward(self, input_data: torch.Tensor) -> Any:
        """
        Just a forward pass through this head
        :returns pd_params - parameters describing the probability distribution
        """
        raise NotImplementedError

    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
        """Logartithm of probability of sampling `action_sample` from a probability described by `pd_params`"""
        raise NotImplementedError

    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
        """Entropy of this distribution"""
        raise NotImplementedError

    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> Any:
        """
        Draw a sample from probability distribution given by those params

        :param pd_params Parameters of a probability distribution
        :param deterministic Whether to return a stochastic sample or deterministic mode of a distribution
        """
        raise NotImplementedError

    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
        """KL divergence between two distribution described by these two params"""
        raise NotImplementedError


class DiagGaussianActionHead(ActionHead):
    """
    Action head where actions are normally distributed uncorrelated variables with specific means and variances.

    Means are calculated directly from the network while standard deviations are a parameter of this module
    """

    LOG2PI = np.log(2.0 * np.pi)

    def __init__(self, input_dim: int, num_dimensions: int):
        super().__init__()

        self.input_dim = input_dim
        self.num_dimensions = num_dimensions

        self.linear_layer = nn.Linear(input_dim, num_dimensions)
        self.log_std = nn.Parameter(torch.zeros(num_dimensions), requires_grad=True)

    def reset_parameters(self):
        init.orthogonal_(self.linear_layer.weight, gain=0.01)
        init.constant_(self.linear_layer.bias, 0.0)

    def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor:
        assert not mask, "Can not use a mask in a gaussian action head"
        means = self.linear_layer(input_data)
        # Unsqueeze many times to get to the same shape
        logstd = self.log_std[(None,) * (len(means.shape) - 1)]

        mean_view, logstd = torch.broadcast_tensors(means, logstd)

        return torch.stack([mean_view, logstd], dim=-1)

    def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
        """Log-likelihood"""
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]

        std = torch.exp(log_std)

        z_score = (action_sample - means) / std

        return -(0.5 * ((z_score ** 2 + self.LOG2PI).sum(dim=-1)) + log_std.sum(dim=-1))

    def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution entropy calculation - sum probs * log(probs).
        In case of diagonal gaussian distribution - 1/2 log(2 pi e sigma^2)
        """
        log_std = pd_params[..., 1]
        return (log_std + 0.5 * (self.LOG2PI + 1)).sum(dim=-1)

    def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]

        if deterministic:
            return means
        else:
            return torch.randn_like(means) * torch.exp(log_std) + means

    def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution KL divergence calculation
        KL(Q || P) = sum Q_i log (Q_i / P_i)

        Formula is:
        log(sigma_p) - log(sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2))/(2 * sigma_p^2)
        """
        means_q = params_q[..., 0]
        log_std_q = params_q[..., 1]

        means_p = params_p[..., 0]
        log_std_p = params_p[..., 1]

        std_q = torch.exp(log_std_q)
        std_p = torch.exp(log_std_p)

        kl_div = log_std_p - log_std_q + (std_q ** 2 + (means_q - means_p) ** 2) / (2.0 * std_p ** 2) - 0.5

        return kl_div.sum(dim=-1, keepdim=True)
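As a sanity check of `DiagGaussianActionHead.logprob`, the per-dimension terms reduce to the standard normal log-density. The scalar sketch below (plain Python, mirroring the tensor formula; not part of the repository) compares the head's expression against a reference log N(x; mu, sigma):

```python
import math

LOG2PI = math.log(2.0 * math.pi)

def diag_gaussian_logprob(sample, means, log_stds):
    # matches DiagGaussianActionHead.logprob:
    # -(0.5 * sum(z^2 + log 2*pi) + sum(log sigma))
    total = 0.0
    for x, mu, ls in zip(sample, means, log_stds):
        z = (x - mu) / math.exp(ls)
        total += 0.5 * (z * z + LOG2PI) + ls
    return -total

def normal_logpdf(x, mu, sigma):
    # reference: log N(x; mu, sigma)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * LOG2PI

lp = diag_gaussian_logprob([0.5, -1.0], [0.0, 0.0], [0.0, math.log(2.0)])
ref = normal_logpdf(0.5, 0.0, 1.0) + normal_logpdf(-1.0, 0.0, 2.0)
```

Expanding the head's formula per dimension gives `-0.5*z^2 - 0.5*log(2*pi) - log(sigma)`, which is exactly the univariate normal log-density, so the two agree.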


class CategoricalActionHead(ActionHead):
    """Action head with categorical actions"""

    def __init__(
        self, input_dim: int, shape: Tuple[int], num_actions: int, builtin_linear_layer: bool = True, temperature: float = 1.0
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_actions = num_actions
        self.output_shape = shape + (num_actions,)
        self.temperature = temperature

        if builtin_linear_layer:
            self.linear_layer = nn.Linear(input_dim, np.prod(self.output_shape))
        else:
            assert (
                input_dim == num_actions
            ), f"If input_dim ({input_dim}) != num_actions ({num_actions}), you need a linear layer to convert them."
            self.linear_layer = None

    def reset_parameters(self):
        if self.linear_layer is not None:
            init.orthogonal_(self.linear_layer.weight, gain=0.01)
            init.constant_(self.linear_layer.bias, 0.0)
            fan_in_linear(self.linear_layer, scale=0.01)

    def forward(self, input_data: torch.Tensor, mask=None) -> Any:
        if self.linear_layer is not None:
            flat_out = self.linear_layer(input_data)
        else:
            flat_out = input_data
        shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape)
        shaped_out /= self.temperature
        if mask is not None:
            shaped_out[~mask] = LOG0

        # Convert to float32 to avoid RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'
        return F.log_softmax(shaped_out.float(), dim=-1)

    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        value = actions.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, logits)
        value = value[..., :1]
        result = log_pmf.gather(-1, value).squeeze(-1)
        # result is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.
        for _ in self.output_shape[:-1]:
            result = result.sum(dim=-1)
        return result

    def entropy(self, logits: torch.Tensor) -> torch.Tensor:
        """Categorical distribution entropy calculation - sum probs * log(probs)"""
        probs = torch.exp(logits)
        entropy = -torch.sum(probs * logits, dim=-1)
        # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.
        for _ in self.output_shape[:-1]:
            entropy = entropy.sum(dim=-1)
        return entropy

    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
        if deterministic:
            return torch.argmax(logits, dim=-1)
        else:
            # Gumbel-max trick.
            u = torch.rand_like(logits)
            # In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0
            # Then the log(-log(1.0)) will give -inf when it should give +inf
            # This is a silly hack to get around that.
            # This hack does not skew the probability distribution, because this event can't possibly win the argmax.
            u[u == 1.0] = 0.999

            return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1)

    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
        """
        Categorical distribution KL divergence calculation
        KL(Q || P) = sum Q_i log (Q_i / P_i)
        When talking about logits this is:
        sum exp(Q_i) * (Q_i - P_i)
        """
        kl = (torch.exp(logits_q) * (logits_q - logits_p)).sum(-1, keepdim=True)
        # kl is per-entry, still of size self.output_shape; we need to reduce over the rest of it.
        for _ in self.output_shape[:-1]:
            kl = kl.sum(dim=-2)  # dim=-2 because we use keepdim=True above.
        return kl
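The non-deterministic branch of `CategoricalActionHead.sample` uses the Gumbel-max trick: adding `-log(-log(u))` noise to the logits and taking the argmax draws exactly from the softmax distribution. A small empirical sketch (illustrative only; the `0.999` clamp mirrors the float16 guard in the head, even though `random.random()` never returns 1.0):

```python
import math, random

def gumbel_sample(logits, rng):
    # argmax(logits + Gumbel noise) draws an index from softmax(logits)
    noisy = []
    for logit in logits:
        u = rng.random()
        if u == 1.0:
            u = 0.999  # guard against log(-log(1.0)) = -inf
        noisy.append(logit - math.log(-math.log(u)))
    return max(range(len(logits)), key=noisy.__getitem__)

rng = random.Random(0)
logits = [math.log(0.8), math.log(0.2)]
n = 20000
freq = sum(gumbel_sample(logits, rng) == 0 for _ in range(n)) / n
```

Over many draws, the empirical frequency of index 0 should be close to its softmax probability of 0.8.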


class DictActionHead(nn.ModuleDict):
    """Action head with multiple sub-actions"""

    def reset_parameters(self):
        for subhead in self.values():
            subhead.reset_parameters()

    def forward(self, input_data: torch.Tensor, **kwargs) -> Any:
        """
        :param kwargs: each kwarg should be a dict with keys corresponding to self.keys()
                e.g. if this ModuleDict has submodules keyed by 'A', 'B', and 'C', we could call:
                    forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7}}
                Then children will be called with:
                    A: forward(input_data, foo=True, bar=7)
                    B: forward(input_data)
                    C: forward(input_data, foo=False)
        """
        result = {}
        for head_name, subhead in self.items():
            head_kwargs = {
                kwarg_name: kwarg[head_name]
                for kwarg_name, kwarg in kwargs.items()
                if kwarg is not None and head_name in kwarg
            }
            result[head_name] = subhead(input_data, **head_kwargs)
        return result

    def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        return sum(subhead.logprob(actions[k], logits[k]) for k, subhead in self.items())

    def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
        return {k: subhead.sample(logits[k], deterministic) for k, subhead in self.items()}

    def entropy(self, logits: torch.Tensor) -> torch.Tensor:
        return sum(subhead.entropy(logits[k]) for k, subhead in self.items())

    def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
        return sum(subhead.kl_divergence(logits_q[k], logits_p[k]) for k, subhead in self.items())


def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):
    """Helper function to create an action head corresponding to the environment action space"""
    if isinstance(ac_space, TensorType):
        if isinstance(ac_space.eltype, Discrete):
            return CategoricalActionHead(pi_out_size, ac_space.shape, ac_space.eltype.n, temperature=temperature)
        elif isinstance(ac_space.eltype, Real):
            if temperature != 1.0:
                logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.")
            assert len(ac_space.shape) == 1, "Nontrivial shapes not yet implemented."
            return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])
    elif isinstance(ac_space, DictType):
        return DictActionHead({k: make_action_head(v, pi_out_size, temperature) for k, v in ac_space.items()})
    raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported")


================================================
FILE: metrics/IDM/lib/action_mapping.py
================================================
import abc
import itertools
from collections import OrderedDict
from typing import Dict, List

import numpy as np
from gym3.types import DictType, Discrete, TensorType

from lib.actions import Buttons


class ActionMapping(abc.ABC):
    """Class that maps between the standard MC factored action space and a new one you define!

    :param n_camera_bins: Need to specify this to define the original ac space for stats code
    """

    # This is the default buttons groups, it can be changed for your action space
    BUTTONS_GROUPS = OrderedDict(
        hotbar=["none"] + [f"hotbar.{i}" for i in range(1, 10)],
        fore_back=["none", "forward", "back"],
        left_right=["none", "left", "right"],
        sprint_sneak=["none", "sprint", "sneak"],
        use=["none", "use"],
        drop=["none", "drop"],
        attack=["none", "attack"],
        jump=["none", "jump"],
    )

    def __init__(self, n_camera_bins: int = 11):
        assert n_camera_bins % 2 == 1, "n_camera_bins should be odd"
        self.n_camera_bins = n_camera_bins
        self.camera_null_bin = n_camera_bins // 2
        self.stats_ac_space = DictType(
            **{
                "buttons": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)),
                "camera": TensorType(shape=(2,), eltype=Discrete(n_camera_bins)),
            }
        )

    @abc.abstractmethod
    def from_factored(self, ac: Dict) -> Dict:
        """Converts a factored action (ac) to the new space

        :param ac: Dictionary of actions that must have a batch dimension
        """
        pass

    @abc.abstractmethod
    def to_factored(self, ac: Dict) -> Dict:
        """Converts an action in the new space (ac) to the factored action space.

        :param ac: Dictionary of actions that must have a batch dimension
        """
        pass

    @abc.abstractmethod
    def get_action_space_update(self):
        """Return a magym (gym3) action space. This will be used to update the env action space."""
        pass

    @abc.abstractmethod
    def get_zero_action(self):
        """Return the zero or null action for this action space"""
        pass

    def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_group: List[str]) -> List[str]:
        """For a mutually exclusive group of buttons in button_group, find which option
        in the group was chosen. Assumes that each button group has the option of 'none'
        meaning that no button in the group was pressed.

        :param ac_buttons: button actions from the factored action space. Should have dims [B, len(Buttons.ALL)]
        :param button_group: List of buttons in a mutually exclusive group. Each item in the
            list should appear in Buttons.ALL except for the special case 'none' which means
            no button in the group was pressed. e.g. ['none', 'forward', 'back']. For now
            'none' must be the first element of button_group

        Returns a list of length B, where each element is an item from button_group.
        """
        assert ac_buttons.shape[1] == len(
            Buttons.ALL
        ), f"There should be {len(Buttons.ALL)} buttons in the factored buttons space"
        assert button_group[0] == "none", "This function only works if 'none' is in button_group"
        # Actions in ac_buttons with order according to button_group
        group_indices = [Buttons.ALL.index(b) for b in button_group if b != "none"]
        ac_choices = ac_buttons[:, group_indices]

        # Special cases for forward/back, left/right where mutual press means do neither
        if "forward" in button_group and "back" in button_group:
            ac_choices[np.all(ac_choices, axis=-1)] = 0
        if "left" in button_group and "right" in button_group:
            ac_choices[np.all(ac_choices, axis=-1)] = 0
        ac_non_zero = np.where(ac_choices)
        ac_choice = ["none" for _ in range(ac_buttons.shape[0])]
        # Iterate over the non-zero indices so that if two buttons in a group were pressed at the same time
        # we give priority to the button later in the group. E.g. if hotbar.1 and hotbar.2 are pressed during the same
        # timestep, hotbar.2 is marked as pressed
        for index, action in zip(ac_non_zero[0], ac_non_zero[1]):
            ac_choice[index] = button_group[action + 1]  # the zero'th index will mean no button pressed
        return ac_choice
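`factored_buttons_to_groups` resolves each mutually exclusive group to a single choice: the later option wins ties, and the opposing pairs forward/back and left/right cancel to 'none' when pressed together. A simplified sketch over sets of pressed button names (the `cancel_on_conflict` flag is an illustration; the real method hard-codes cancellation for just those two pairs):

```python
def group_choice(pressed, group, cancel_on_conflict=False):
    # group: mutually exclusive options with 'none' first,
    # e.g. ['none', 'forward', 'back']
    hits = [b for b in group[1:] if b in pressed]
    if cancel_on_conflict and len(hits) > 1:
        return "none"  # e.g. forward+back together do neither
    return hits[-1] if hits else "none"  # later option wins ties

fore_back = ["none", "forward", "back"]
hotbar = ["none"] + [f"hotbar.{i}" for i in range(1, 10)]
```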

class IDMActionMapping(ActionMapping):
    """For IDM, but essentially this is just an identity mapping"""
    def from_factored(self, ac: Dict) -> Dict:
        return ac

    def to_factored(self, ac: Dict) -> Dict:
        return ac

    def get_action_space_update(self):
        """Return a magym (gym3) action space. This will be used to update the env action space."""
        return {
            "buttons": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)),
            "camera": TensorType(shape=(2,), eltype=Discrete(self.n_camera_bins)),
        }

    def get_zero_action(self):
        raise NotImplementedError()

class CameraHierarchicalMapping(ActionMapping):
    """Buttons are joint as in ButtonsJointMapping, but now a camera on/off meta action is added into this joint space.
    When this meta action is triggered, the separate camera head chooses a camera action which is also now a joint space.

    :param n_camera_bins: number of camera bins in the factored space
    """

    # Add camera meta action to BUTTONS_GROUPS
    BUTTONS_GROUPS = ActionMapping.BUTTONS_GROUPS.copy()
    BUTTONS_GROUPS["camera"] = ["none", "camera"]
    BUTTONS_COMBINATIONS = list(itertools.product(*BUTTONS_GROUPS.values())) + ["inventory"]
    BUTTONS_COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(BUTTONS_COMBINATIONS)}
    BUTTONS_IDX_TO_COMBINATION = {i: comb for i, comb in enumerate(BUTTONS_COMBINATIONS)}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.camera_groups = OrderedDict(
            camera_x=[f"camera_x{i}" for i in range(self.n_camera_bins)],
            camera_y=[f"camera_y{i}" for i in range(self.n_camera_bins)],
        )
        self.camera_combinations = list(itertools.product(*self.camera_groups.values()))
        self.camera_combination_to_idx = {comb: i for i, comb in enumerate(self.camera_combinations)}
        self.camera_idx_to_combination = {i: comb for i, comb in enumerate(self.camera_combinations)}
        self.camera_null_idx = self.camera_combination_to_idx[
            (f"camera_x{self.camera_null_bin}", f"camera_y{self.camera_null_bin}")
        ]
        self._null_action = {
            "buttons": self.BUTTONS_COMBINATION_TO_IDX[tuple("none" for _ in range(len(self.BUTTONS_GROUPS)))]
        }
        self._precompute_to_factored()

    def _precompute_to_factored(self):
        """Precompute the joint action -> factored action matrix."""
        button_dim = self.stats_ac_space["buttons"].size
        self.BUTTON_IDX_TO_FACTORED = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION), button_dim), dtype=int)
        self.BUTTON_IDX_TO_CAMERA_META_OFF = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION)), dtype=bool)
        self.CAMERA_IDX_TO_FACTORED = np.zeros((len(self.camera_idx_to_combination), 2), dtype=int)

        # Pre compute Buttons
        for jnt_ac, button_comb in self.BUTTONS_IDX_TO_COMBINATION.items():
            new_button_ac = np.zeros(len(Buttons.ALL), dtype="i")
            if button_comb == "inventory":
                new_button_ac[Buttons.ALL.index("inventory")] = 1
            else:
                for group_choice in button_comb[:-1]:  # Last one is camera
                    if group_choice != "none":
                        new_button_ac[Buttons.ALL.index(group_choice)] = 1

                if button_comb[-1] != "camera":  # This means camera meta action is off
                    self.BUTTON_IDX_TO_CAMERA_META_OFF[jnt_ac] = True
            self.BUTTON_IDX_TO_FACTORED[jnt_ac] = new_button_ac

        # Pre compute camera
        for jnt_ac, camera_comb in self.camera_idx_to_combination.items():
            new_camera_ac = np.ones((2), dtype="i") * self.camera_null_bin
            new_camera_ac[0] = self.camera_groups["camera_x"].index(camera_comb[0])
            new_camera_ac[1] = self.camera_groups["camera_y"].index(camera_comb[1])
            self.CAMERA_IDX_TO_FACTORED[jnt_ac] = new_camera_ac

    def from_factored(self, ac: Dict) -> Dict:
        """Converts a factored action (ac) to the new space. Assumes ac has a batch dim"""
        assert ac["camera"].ndim == 2, f"bad camera label, {ac['camera']}"
        assert ac["buttons"].ndim == 2, f"bad buttons label, {ac['buttons']}"
        # Get button choices for everything but camera
        choices_by_group = OrderedDict(
            (k, self.factored_buttons_to_groups(ac["buttons"], v)) for k, v in self.BUTTONS_GROUPS.items() if k != "camera"
        )
        # Set camera "on off" action based on whether non-null camera action was given
        camera_is_null = np.all(ac["camera"] == self.camera_null_bin, axis=1)
        choices_by_group["camera"] = ["none" if is_null else "camera" for is_null in camera_is_null]

        new_button_ac = []
        new_camera_ac = []
        for i in range(ac["buttons"].shape[0]):
            # Buttons
            key = tuple([v[i] for v in choices_by_group.values()])
            if ac["buttons"][i, Buttons.ALL.index("inventory")] == 1:
                key = "inventory"
            new_button_ac.append(self.BUTTONS_COMBINATION_TO_IDX[key])

            # Camera -- inventory is also exclusive with camera
            if key == "inventory":
                key = (
                    f"camera_x{self.camera_null_bin}",
                    f"camera_y{self.camera_null_bin}",
                )
            else:
                key = (f"camera_x{ac['camera'][i][0]}", f"camera_y{ac['camera'][i][1]}")
            new_camera_ac.append(self.camera_combination_to_idx[key])

        return dict(
            buttons=np.array(new_button_ac)[:, None],
            camera=np.array(new_camera_ac)[:, None],
        )

    def to_factored(self, ac: Dict) -> Dict:
        """Converts an action in the new space (ac) to the factored action space. Assumes ac has a batch dim"""
        assert ac["camera"].shape[-1] == 1
        assert ac["buttons"].shape[-1] == 1

        new_button_ac = self.BUTTON_IDX_TO_FACTORED[np.squeeze(ac["buttons"], -1)]
        camera_off = self.BUTTON_IDX_TO_CAMERA_META_OFF[np.squeeze(ac["buttons"], -1)]
        new_camera_ac = self.CAMERA_IDX_TO_FACTORED[np.squeeze(ac["camera"], -1)]
        new_camera_ac[camera_off] = self.camera_null_bin

        return dict(buttons=new_button_ac, camera=new_camera_ac)

    def get_action_space_update(self):
        return {
            "camera": TensorType(shape=(1,), eltype=Discrete(len(self.camera_combinations))),
            "buttons": TensorType(shape=(1,), eltype=Discrete(len(self.BUTTONS_COMBINATIONS))),
        }

    def get_zero_action(self):
        return self._null_action
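
The factored-to-joint conversion above hinges on the null-camera check: a camera action counts as "null" only when both axes sit at `camera_null_bin`. A minimal standalone sketch of that check (the bin values here are hypothetical; the real class derives `camera_null_bin` from its quantizer settings):

```python
import numpy as np

# Hypothetical settings: camera_maxval=10, camera_binsize=2 -> null bin = 10 // 2 = 5
camera_null_bin = 5

# Batch of factored camera actions, shape (batch, 2) = (x bin, y bin)
camera = np.array([[5, 5],   # both axes at the null bin -> no camera movement
                   [5, 7],   # y axis moved -> real camera action
                   [3, 5]])  # x axis moved -> real camera action

# Mirrors the check in from_factored: a row is null only if every component
# equals the null bin
camera_is_null = np.all(camera == camera_null_bin, axis=1)
print(camera_is_null)  # [ True False False]
```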



================================================
FILE: metrics/IDM/lib/actions.py
================================================
import attr
# import minerl.herobraine.hero.mc as mc
import numpy as np

from lib.minecraft_util import store_args


class Buttons:
    ATTACK = "attack"
    BACK = "back"
    FORWARD = "forward"
    JUMP = "jump"
    LEFT = "left"
    RIGHT = "right"
    SNEAK = "sneak"
    SPRINT = "sprint"
    USE = "use"
    DROP = "drop"
    INVENTORY = "inventory"

    ALL = [
        ATTACK,
        BACK,
        FORWARD,
        JUMP,
        LEFT,
        RIGHT,
        SNEAK,
        SPRINT,
        USE,
        DROP,
        INVENTORY,
    ] + [f"hotbar.{i}" for i in range(1, 10)]


class SyntheticButtons:
    # Composite / scripted actions
    CHANNEL_ATTACK = "channel-attack"

    ALL = [CHANNEL_ATTACK]


class QuantizationScheme:
    LINEAR = "linear"
    MU_LAW = "mu_law"


@attr.s(auto_attribs=True)
class CameraQuantizer:
    """
    A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components.

    Parameters:
    - camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize.
    - camera_maxval: The maximum value of the camera action.
    - quantization_scheme: The quantization scheme to use. Currently, two quantization schemes are supported:
    - Linear quantization (default): Camera actions are split uniformly into discrete bins
    - Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm)
    followed by the same quantization scheme used by the linear scheme.
    - mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of
    mu will result in a sharper transition near zero. Below are some reference values listed
    for choosing mu given a constant maxval and a desired max_precision value.
    maxval = 10 | max_precision = 0.5  | μ ≈ 2.93826
    maxval = 10 | max_precision = 0.4  | μ ≈ 4.80939
    maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887
    maxval = 20 | max_precision = 0.5  | μ ≈ 2.7
    maxval = 20 | max_precision = 0.4  | μ ≈ 4.39768
    maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194
    maxval = 40 | max_precision = 0.5  | μ ≈ 2.60780
    maxval = 40 | max_precision = 0.4  | μ ≈ 4.21554
    maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152
    """

    camera_maxval: int
    camera_binsize: int
    quantization_scheme: str = attr.ib(
        default=QuantizationScheme.LINEAR,
        validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),
    )
    mu: float = attr.ib(default=5)

    def discretize(self, xy):
        xy = np.clip(xy, -self.camera_maxval, self.camera_maxval)

        if self.quantization_scheme == QuantizationScheme.MU_LAW:
            xy = xy / self.camera_maxval
            v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu))
            v_encode *= self.camera_maxval
            xy = v_encode

        # Quantize using linear scheme
        return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64)

    def undiscretize(self, xy):
        xy = xy * self.camera_binsize - self.camera_maxval

        if self.quantization_scheme == QuantizationScheme.MU_LAW:
            xy = xy / self.camera_maxval
            v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0)
            v_decode *= self.camera_maxval
            xy = v_decode
        return xy


class ActionTransformer:
    """Transforms actions between internal array and minerl env format."""

    @store_args
    def __init__(
        self,
        camera_maxval=10,
        camera_binsize=2,
        camera_quantization_scheme="linear",
        camera_mu=5,
    ):
        self.quantizer = CameraQuantizer(
            camera_maxval=camera_maxval,
            camera_binsize=camera_binsize,
            quantization_scheme=camera_quantization_scheme,
            mu=camera_mu,
        )

    def camera_zero_bin(self):
        return self.camera_maxval // self.camera_binsize

    def discretize_camera(self, xy):
        return self.quantizer.discretize(xy)

    def undiscretize_camera(self, pq):
        return self.quantizer.undiscretize(pq)

    def item_embed_id_to_name(self, item_id):
        # NOTE: requires the `minerl` import (`mc`) that is commented out at the top of this file
        return mc.MINERL_ITEM_MAP[item_id]

    def dict_to_numpy(self, acs):
        """
        Env format to policy output format.
        """
        act = {
            "buttons": np.stack([acs.get(k, 0) for k in Buttons.ALL], axis=-1),
            "camera": self.discretize_camera(acs["camera"]),
        }
        if not self.human_spaces:  # NOTE: `human_spaces` is not set in __init__; this branch assumes the caller defines it
            act.update(
                {
                    "synthetic_buttons": np.stack([acs[k] for k in SyntheticButtons.ALL], axis=-1),
                    "place": self.item_embed_name_to_id(acs["place"]),
                    "equip": self.item_embed_name_to_id(acs["equip"]),
                    "craft": self.item_embed_name_to_id(acs["craft"]),
                }
            )
        return act

    def numpy_to_dict(self, acs):
        """
        Numpy policy output to env-compatible format.
        """
        assert acs["buttons"].shape[-1] == len(
            Buttons.ALL
        ), f"Mismatched actions: {acs}; expected {len(Buttons.ALL)}:\n(  {Buttons.ALL})"
        out = {name: acs["buttons"][..., i] for (i, name) in enumerate(Buttons.ALL)}

        out["camera"] = self.undiscretize_camera(acs["camera"])

        return out

    def policy2env(self, acs):
        acs = self.numpy_to_dict(acs)
        return acs

    def env2policy(self, acs):
        nbatch = acs["camera"].shape[0]
        dummy = np.zeros((nbatch,))
        out = {
            "camera": self.discretize_camera(acs["camera"]),
            "buttons": np.stack([acs.get(k, dummy) for k in Buttons.ALL], axis=-1),
        }
        return out
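
With the default `ActionTransformer` settings (`camera_maxval=10`, `camera_binsize=2`) and the linear scheme, `CameraQuantizer.discretize`/`undiscretize` reduce to a clip-shift-round and its (lossy) inverse. A self-contained sketch of just that linear branch:

```python
import numpy as np

camera_maxval, camera_binsize = 10, 2  # ActionTransformer defaults

def discretize_linear(xy):
    # Clip to [-maxval, maxval], then map to integer bins 0 .. 2*maxval/binsize
    xy = np.clip(xy, -camera_maxval, camera_maxval)
    return np.round((xy + camera_maxval) / camera_binsize).astype(np.int64)

def undiscretize_linear(bins):
    # Inverse map: bin index back to degrees (lossy, up to binsize/2 error)
    return bins * camera_binsize - camera_maxval

print(discretize_linear(np.array([0.0, 10.0, -10.0])))  # [ 5 10  0]
print(undiscretize_linear(np.array([5, 10, 0])))        # [  0  10 -10]
# The zero bin matches ActionTransformer.camera_zero_bin(): maxval // binsize == 5
```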


================================================
FILE: metrics/IDM/lib/impala_cnn.py
================================================
import math
from copy import deepcopy
from typing import Dict, List, Optional

from torch import nn
from torch.nn import functional as F

from lib import misc
from lib import torch_util as tu
from lib.util import FanInInitReLULayer


class CnnBasicBlock(nn.Module):
    """
    Residual basic block, as in ImpalaCNN. Preserves channel number and shape
    :param inchan: number of input channels
    :param init_scale: weight init scale multiplier
    """

    def __init__(
        self,
        inchan: int,
        init_scale: float = 1,
        log_scope="",
        init_norm_kwargs: Dict = {},
        **kwargs,
    ):
        super().__init__()
        self.inchan = inchan
        s = math.sqrt(init_scale)
        self.conv0 = FanInInitReLULayer(
            self.inchan,
            self.inchan,
            kernel_size=3,
            padding=1,
            init_scale=s,
            log_scope=f"{log_scope}/conv0",
            **init_norm_kwargs,
        )
        self.conv1 = FanInInitReLULayer(
            self.inchan,
            self.inchan,
            kernel_size=3,
            padding=1,
            init_scale=s,
            log_scope=f"{log_scope}/conv1",
            **init_norm_kwargs,
        )

    def forward(self, x):
        x = x + self.conv1(self.conv0(x))
        return x


class CnnDownStack(nn.Module):
    """
    Downsampling stack from Impala CNN.
    :param inchan: number of input channels
    :param nblock: number of residual blocks after downsampling
    :param outchan: number of output channels
    :param init_scale: weight init scale multiplier
    :param pool: if true, downsample with max pool
    :param post_pool_groups: if not None, normalize with group norm with this many groups
    :param kwargs: remaining kwargs are passed into the blocks and layers
    """

    name = "Impala_CnnDownStack"

    def __init__(
        self,
        inchan: int,
        nblock: int,
        outchan: int,
        init_scale: float = 1,
        pool: bool = True,
        post_pool_groups: Optional[int] = None,
        log_scope: str = "",
        init_norm_kwargs: Dict = {},
        first_conv_norm=False,
        **kwargs,
    ):
        super().__init__()
        self.inchan = inchan
        self.outchan = outchan
        self.pool = pool
        first_conv_init_kwargs = deepcopy(init_norm_kwargs)
        if not first_conv_norm:
            first_conv_init_kwargs["group_norm_groups"] = None
            first_conv_init_kwargs["batch_norm"] = False
        self.firstconv = FanInInitReLULayer(
            inchan,
            outchan,
            kernel_size=3,
            padding=1,
            log_scope=f"{log_scope}/firstconv",
            **first_conv_init_kwargs,
        )
        self.post_pool_groups = post_pool_groups
        if post_pool_groups is not None:
            self.n = nn.GroupNorm(post_pool_groups, outchan)
        self.blocks = nn.ModuleList(
            [
                CnnBasicBlock(
                    outchan,
                    init_scale=init_scale / math.sqrt(nblock),
                    log_scope=f"{log_scope}/block{i}",
                    init_norm_kwargs=init_norm_kwargs,
                    **kwargs,
                )
                for i in range(nblock)
            ]
        )

    def forward(self, x):
        x = self.firstconv(x)
        if self.pool:
            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
            if self.post_pool_groups is not None:
                x = self.n(x)
        x = tu.sequential(self.blocks, x, diag_name=self.name)
        return x

    def output_shape(self, inshape):
        c, h, w = inshape
        assert c == self.inchan
        if self.pool:
            return (self.outchan, (h + 1) // 2, (w + 1) // 2)
        else:
            return (self.outchan, h, w)


class ImpalaCNN(nn.Module):
    """
    :param inshape: input image shape (height, width, channels)
    :param chans: list of channel counts, one per residual downsampling stack. Each
        element is the number of filters per convolution in that stack
    :param outsize: output hidden size
    :param nblock: number of residual blocks per stack. Each block has 2 convs and a residual
    :param init_norm_kwargs: arguments to be passed to convolutional layers. Options can be found
        in ypt.model.util:FanInInitReLULayer
    :param dense_init_norm_kwargs: arguments to be passed to the final dense layer. Options can be found
        in ypt.model.util:FanInInitReLULayer
    :param kwargs: remaining kwargs are passed into the CnnDownStacks
    """

    name = "ImpalaCNN"

    def __init__(
        self,
        inshape: List[int],
        chans: List[int],
        outsize: int,
        nblock: int,
        init_norm_kwargs: Dict = {},
        dense_init_norm_kwargs: Dict = {},
        first_conv_norm=False,
        **kwargs,
    ):
        super().__init__()
        h, w, c = inshape
        curshape = (c, h, w)
        self.stacks = nn.ModuleList()
        for i, outchan in enumerate(chans):
            stack = CnnDownStack(
                curshape[0],
                nblock=nblock,
                outchan=outchan,
                init_scale=math.sqrt(len(chans)),
                log_scope=f"downstack{i}",
                init_norm_kwargs=init_norm_kwargs,
                first_conv_norm=first_conv_norm if i == 0 else True,
                **kwargs,
            )
            self.stacks.append(stack)
            curshape = stack.output_shape(curshape)

        self.dense = FanInInitReLULayer(
            misc.intprod(curshape),
            outsize,
            layer_type="linear",
            log_scope="imapala_final_dense",
            init_scale=1.4,
            **dense_init_norm_kwargs,
        )
        self.outsize = outsize

    def forward(self, x):
        b, t = x.shape[:-3]
        x = x.reshape(b * t, *x.shape[-3:])
        x = misc.transpose(x, "bhwc", "bchw")
        x = tu.sequential(self.stacks, x, diag_name=self.name)
        x = x.reshape(b, t, *x.shape[1:])
        x = tu.flatten_image(x)
        x = self.dense(x)
        return x
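
Each pooling `CnnDownStack` halves the spatial dimensions via `max_pool2d(kernel_size=3, stride=2, padding=1)`, i.e. `(h, w) -> ((h + 1) // 2, (w + 1) // 2)`, which is exactly what `output_shape` reports. A sketch of the shape bookkeeping `ImpalaCNN.__init__` performs (the input shape and `chans` list here are hypothetical, not the repo's configuration):

```python
def downstack_shape(inshape, outchan, pool=True):
    # Mirrors CnnDownStack.output_shape: channels change, spatial dims halve on pool
    c, h, w = inshape
    return (outchan, (h + 1) // 2, (w + 1) // 2) if pool else (outchan, h, w)

curshape = (3, 128, 128)          # (c, h, w) after the bhwc -> bchw transpose
for outchan in [64, 128, 128]:    # hypothetical `chans`
    curshape = downstack_shape(curshape, outchan)
print(curshape)  # (128, 16, 16) -> flattened into the final dense layer
```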


================================================
FILE: metrics/IDM/lib/masked_attention.py
================================================
import functools

import torch as th
from torch import nn

import lib.xf as xf
from lib.minecraft_util import store_args
from lib.tree_util import tree_map


@functools.lru_cache()
def get_band_diagonal_mask(t: int, T: int, maxlen: int, batchsize: int, device: th.device) -> th.Tensor:
    """Returns a band diagonal mask which is causal (upper triangle is masked)
    and such that any frame can only view up to maxlen total past frames
    including the current frame.

    Example Masks: Here 0 means that frame is masked; masking is done by adding a large negative number to the attention logits (see orc.xf)
        t = 3, T = 3, maxlen = 3
          T
        t 1 0 0 |  mask out T > t
          1 1 0 |
          1 1 1 |
        t = 3, T = 6, maxlen = 3
        t 0 1 1 1 0 0 |  mask out T > t
          0 0 1 1 1 0 |
          0 0 0 1 1 1 |

    Args:
        t: number of rows (presumably number of frames receiving gradient)
        T: number of cols (presumably t + past context that isn't being gradient updated)
        maxlen: maximum number of frames (including current frame) any frame can attend to
        batchsize: number of masks to return
        device: torch device to place mask on

    Returns:
        Boolean mask of shape (batchsize, t, T)
    """
    m = th.ones(t, T, dtype=bool)
    m.tril_(T - t)  # Mask out upper triangle
    if maxlen is not None and maxlen < T:  # Mask out lower triangle
        m.triu_(T - t - maxlen + 1)
    m_btT = m[None].repeat_interleave(batchsize, dim=0)
    m_btT = m_btT.to(device=device)
    return m_btT


def get_mask(first_b11: th.Tensor, state_mask: th.Tensor, t: int, T: int, maxlen: int, heads: int, device) -> th.Tensor:
    """Returns a band diagonal mask that respects masking past states (columns 0:T-t inclusive)
        if first_b11 is True. See get_band_diagonal_mask for how the base mask is computed.
        This function takes that mask and first zeros out any past context if first_b11 is True.

        Say our context is in chunks of length t (so here T = 4t). We see that in the second chunk we received first=True
        context     t t t t
        first       F T F F
        Given this, the mask should mask out anything prior to that chunk; however, since we don't have access to the past
        first_b11's, we need to keep a state of the mask at those past timesteps. This is what state_mask is.

        In particular state_mask is a [b, t, T - t] mask matrix that contains the mask for the past T - t frames.

    Args: (See get_band_diagonal_mask for remaining args)
        first_b11: boolean tensor with shape [batchsize, 1, 1] indicating if the first timestep for each batch element had first=True
        state_mask: mask tensor of shape [b, t, T - t]
        t: number of mask rows (presumably number of frames for which we take gradient)
        T: number of mask columns (t + the number of past frames we keep in context)
        maxlen: actual context length
        heads: number of attention heads
        device: torch device

    Returns:
        m_btT: Boolean mask of shape (batchsize * heads, t, T)
        state_mask: updated state_mask
    """
    b = first_b11.shape[0]

    if state_mask is None:
        state_mask = th.zeros((b, 1, T - t), dtype=bool, device=device)

    m_btT = get_band_diagonal_mask(t, T, maxlen, b, device).clone()  # Should be shape B, t, T
    not_first = ~first_b11.to(device=device)
    m_btT[:, :, :-t] &= not_first  # Zero out anything in the past if first is true
    m_btT[:, :, :-t] &= state_mask
    m_bhtT = m_btT[:, None].repeat_interleave(heads, dim=1)
    m_btT = m_bhtT.reshape((b * heads), t, T)

    # Update state_mask such that it reflects the most recent first
    state_mask = th.cat(
        [
            state_mask[:, :, t:] & not_first,
            th.ones((b, 1, min(t, T - t)), dtype=bool, device=device),
        ],
        dim=-1,
    )

    return m_btT, state_mask


class MaskedAttention(nn.Module):
    """
    Transformer self-attention layer that removes frames from previous episodes from the hidden state under certain constraints.

    The constraints are:
    - The "first" flag can only be true for the first timestep of each batch. An assert will fire if other timesteps have first = True.

    input_size: The dimension of the input (which also happens to be the size of the output)
    memory_size: The number of frames to keep in the inner state. Note that when attending, we will be able to attend
                 to both the frames in the inner state (which presumably won't have gradients anymore) and the frames
                 in the batch. See "mask" below for some additional considerations on this.
    heads: The number of attention heads to use. Note that we will split the input into this number of heads, so
           input_size needs to be divisible by heads.
    timesteps: number of timesteps with which we'll be taking gradient
    mask: Can be "none" or "clipped_causal". "clipped_causal" is a normal causal mask but solves the following minor problem:
        if you have a state of length 128 and a batch of 128 frames, then the first frame of your batch will be able to
        attend to 128 previous frames, but the last one will be able to attend to 255 previous frames. In this example,
        "clipped_causal" will make it so that the last frame can only attend to 128 previous frames, so that there is no
        bias coming from the position in the batch. None simply allows you to attend to any frame in the state + batch,
        which means you can also attend to future frames.
    """

    @store_args
    def __init__(
        self,
        input_size,
        memory_size: int,
        heads: int,
        timesteps: int,
        mask: str = "clipped_causal",
        init_scale=1,
        norm="none",
        log_scope="sa",
        use_muP_factor=False,
    ):
        super().__init__()

        assert mask in {"none", "clipped_causal"}
        assert memory_size >= 0

        self.maxlen = memory_size - timesteps
        if mask == "none":
            mask = None

        self.orc_attn = xf.All2All(heads, self.maxlen, mask=mask is not None)
        self.orc_block = xf.SelfAttentionLayer(
            input_size,
            self.orc_attn,
            scale=init_scale,
            relattn=True,
            cache_keep_len=self.maxlen,
            norm=norm,
            log_scope=log_scope,
            use_muP_factor=use_muP_factor,
        )

    def initial_state(self, batchsize: int, device=None):
        """Return the initial state mask (None) and the initial state of the transformer (zerod out keys and queries)"""
        state = self.orc_block.initial_state(batchsize, initial_T=self.maxlen)
        state_mask = None
        if device is not None:
            state = tree_map(lambda x: x.to(device), state)
        return state_mask, state

    def forward(self, input_bte, first_bt, state):
        """Forward propagation of a single layer"""
        state_mask, xf_state = state
        t = first_bt.shape[1]
        if self.mask == "clipped_causal":
            new_mask, state_mask = get_mask(
                first_b11=first_bt[:, [[0]]],
                state_mask=state_mask,
                t=t,
                T=t + self.maxlen,
                maxlen=self.maxlen,
                heads=self.heads,
                device=input_bte.device,
            )
            self.orc_block.attn.mask = new_mask
        output, xf_state = self.orc_block(input_bte, xf_state)

        return output, (state_mask, xf_state)

    def get_log_keys(self):
        # These are logged in xf.SelfAttentionLayer
        return [f"activation_{stat}/{self.log_scope}/{k}" for k in ["K", "Q", "V", "A", "Aproj"] for stat in ["mean", "std"]]
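
The band-diagonal mask described in `get_band_diagonal_mask`'s docstring can be reproduced with the same `tril`/`triu` offsets in plain numpy, which makes the two worked examples easy to check. A minimal sketch (numpy stands in for torch; the logic is otherwise the same):

```python
import numpy as np

def band_diagonal_mask(t, T, maxlen):
    # Causal: row i (frame T - t + i) may attend only to columns j <= i + (T - t)
    m = np.tril(np.ones((t, T), dtype=bool), T - t)
    # Clip: at most `maxlen` frames including the current one, i.e. j >= i + (T - t - maxlen + 1)
    if maxlen is not None and maxlen < T:
        m &= np.triu(np.ones((t, T), dtype=bool), T - t - maxlen + 1)
    return m

# Reproduces the t=3, T=6, maxlen=3 example from the docstring
print(band_diagonal_mask(3, 6, 3).astype(int))
# [[0 1 1 1 0 0]
#  [0 0 1 1 1 0]
#  [0 0 0 1 1 1]]
```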


================================================
FILE: metrics/IDM/lib/minecraft_util.py
================================================
import functools
import inspect
from typing import Optional, Tuple

import numpy as np
import torch

from lib.action_head import (CategoricalActionHead, DiagGaussianActionHead,
                             DictActionHead)


def store_args(method):
    """Stores provided method args as instance attributes."""
    argspec = inspect.getfullargspec(method)
    defaults = {}
    if argspec.defaults is not None:
        defaults = dict(zip(argspec.args[-len(argspec.defaults) :], argspec.defaults))
    if argspec.kwonlydefaults is not None:
        defaults.update(argspec.kwonlydefaults)
    arg_names = argspec.args[1:]

    @functools.wraps(method)
    def wrapper(*positional_args, **keyword_args):
        self = positional_args[0]
        # Get default arg values
        args = defaults.copy()
        # Add provided arg values
        for name, value in zip(arg_names, positional_args[1:]):
            args[name] = value
        args.update(keyword_args)
        self.__dict__.update(args)
        return method(*positional_args, **keyword_args)

    return wrapper


def get_norm_entropy_from_cat_head(module, name, masks, logits):
    # Note that the mask has already been applied to the logits at this point
    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
    if name in masks:
        n = torch.sum(masks[name], dim=-1, dtype=torch.float)
        norm_entropy = entropy / torch.log(n)
        # When the mask only allows one option the normalized entropy makes no sense
        # as it is basically both maximal (the distribution is as uniform as it can be)
        # and minimal (there is no variance at all).
        # As such, we ignore them for the purpose of calculating entropy.
        zero = torch.zeros_like(norm_entropy)
        norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy)
        count = n.not_equal(1.0).int()
    else:
        n = torch.tensor(logits.shape[-1], dtype=torch.float)
        norm_entropy = entropy / torch.log(n)
        count = torch.ones_like(norm_entropy, dtype=torch.int)

    # entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce over the rest of it.
    for _ in module.output_shape[:-1]:
        norm_entropy = norm_entropy.sum(dim=-1)
        count = count.sum(dim=-1)
    return norm_entropy, count


def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]:
    entropy_sum = torch.zeros_like(template, dtype=torch.float)
    counts = torch.zeros_like(template, dtype=torch.int)
    for k, subhead in module.items():
        if isinstance(subhead, DictActionHead):
            entropy, count = get_norm_cat_entropy(subhead, masks, logits[k], template)
        elif isinstance(subhead, CategoricalActionHead):
            entropy, count = get_norm_entropy_from_cat_head(subhead, k, masks, logits[k])
        else:
            continue
        entropy_sum += entropy
        counts += count
    return entropy_sum, counts


def get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]:
    entropy_sum = torch.zeros_like(template, dtype=torch.float)
    count = torch.zeros(1, device=template.device, dtype=torch.int)
    for k, subhead in module.items():
        if isinstance(subhead, DictActionHead):
            entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template)
        elif isinstance(subhead, DiagGaussianActionHead):
            entropy_sum += module.entropy(logits)
        else:
            continue
        count += 1
    return entropy_sum / count
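
In `get_norm_entropy_from_cat_head` the `logits` are log-probabilities (the categorical heads apply log-softmax), so `-sum(exp(logits) * logits)` is the Shannon entropy and dividing by `log(n)` normalizes it into [0, 1]. A standalone numpy sketch of that computation (the example distributions are illustrative, not from the repo):

```python
import numpy as np

def norm_entropy(log_probs):
    # entropy = -sum(p * log p), with p = exp(log_probs); normalized by log(n)
    entropy = -np.sum(np.exp(log_probs) * log_probs, axis=-1)
    return entropy / np.log(log_probs.shape[-1])

uniform = np.log(np.full(4, 0.25))                     # maximally uncertain
peaked = np.log(np.array([0.97, 0.01, 0.01, 0.01]))    # nearly deterministic
print(norm_entropy(uniform))  # 1.0 (up to float rounding)
print(norm_entropy(peaked))   # small, close to 0
```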


================================================
FILE: metrics/IDM/lib/misc.py
================================================
import numpy as np
import torch as th


def intprod(xs):
    """
    Product of a sequence of integers
    """
    out = 1
    for x in xs:
        out *= x
    return out


def safezip(*args):
    """
    Check that lengths of sequences are the same, then zip them
    """
    args = [list(a) for a in args]
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, f"length mismatch: {list(map(len, args))}"
    return list(zip(*args))


def transpose(x, before, after):
    """
    Usage: x_bca = transpose(x_abc, 'abc', 'bca')
    """
    assert sorted(before) == sorted(after), f"cannot transpose {before} to {after}"
    assert x.ndim == len(
        before
    ), f"before spec '{before}' has length {len(before)} but x has {x.ndim} dimensions: {tuple(x.shape)}"
    return x.permute(tuple(before.index(i) for i in after))


def transpose_undo(x, before, after, *, undo=None):
    """
    Usage:
    x_bca, undo = transpose_undo(x_abc, 'abc', 'bca')
    x_bca = fully_connected_layer(x_bca)
    x_abc = undo(x_bca)
    """
    return (
        transpose(x, before, after),
        compose_undo(undo, lambda x: transpose(x, before=after, after=before)),
    )


def compose_undo(u1, u2):
    assert u2 is not None
    if u1 is None:
        return u2

    def u(x):
        x = u2(x)
        x = u1(x)
        return x

    return u


NO_BIND = "__nobind"


def _parse_reshape_str(s, kind):
    assert kind in ("before", "after")
    result = []
    n_underscores = 0
    for i, part in enumerate(s.split(",")):
        part = part.strip()
        if part == "?" and kind == "before":
            result.append([f"__{i}"])
        elif part == "_":
            result.append([f"{NO_BIND}_{n_underscores}"])
            n_underscores += 1
        else:
            result.append([term.strip() for term in part.split("*")])
    return result


def _infer_part(part, concrete_dim, known, index, full_shape):
    if type(part) is int:
        return part
    assert isinstance(part, list), part
    lits = []
    syms = []
    for term in part:
        if type(term) is int:
            lits.append(term)
        elif type(term) is str:
            syms.append(term)
        else:
            raise TypeError(f"got {type(term)} but expected int or str")
    int_part = 1
    for x in lits:
        int_part *= x
    if len(syms) == 0:
        return int_part
    elif len(syms) == 1 and concrete_dim is not None:
        assert concrete_dim % int_part == 0, f"{concrete_dim} % {int_part} != 0 (at index {index}, full shape is {full_shape})"
        v = concrete_dim // int_part
        if syms[0] in known:
            assert (
                known[syms[0]] == v
            ), f"known value for {syms[0]} is {known[syms[0]]} but found value {v} at index {index} (full shape is {full_shape})"
        else:
            known[syms[0]] = v
        return concrete_dim
    else:
        for i in range(len(syms)):
            if syms[i] in known:
                syms[i] = known[syms[i]]
            else:
                try:
                    syms[i] = int(syms[i])
                except ValueError:
                    pass
        return lits + syms


def _infer_step(args):
    known, desc, shape = args
    new_known = known.copy()
    new_desc = desc.copy()
    for i in range(len(desc)):
        if shape is None:
            concrete_dim = None
        else:
            concrete_dim = shape[i]
        new_desc[i] = _infer_part(part=desc[i], concrete_dim=concrete_dim, known=new_known, index=i, full_shape=shape)
    return new_known, new_desc, shape


def _infer(known, desc, shape):
    if shape is not None:
        assert len(desc) == len(shape), f"desc has length {len(desc)} but shape has length {len(shape)} (shape={shape})"
    known, desc, shape = fixed_point(_infer_step, (known, desc, shape))
    return desc, known


def fixed_point(f, x, eq=None):
    if eq is None:
        eq = lambda a, b: a == b
    while True:
        new_x = f(x)
        if eq(x, new_x):
            return x
        else:
            x = new_x


def _infer_question_mark(x, total_product):
    try:
        question_mark_index = x.index(["?"])
    except ValueError:
        return x
    observed_product = 1
    for i in range(len(x)):
        if i != question_mark_index:
            assert type(x[i]) is int, f"when there is a question mark, there can be no other unknown values (full list: {x})"
            observed_product *= x[i]
    assert (
        observed_product and total_product % observed_product == 0
    ), f"{total_product} is not divisible by {observed_product}"
    value = total_product // observed_product
    x = x.copy()
    x[question_mark_index] = value
    return x


def _ground(x, known, infer_question_mark_with=None):
    x, known = _infer(known=known, desc=x, shape=None)
    if infer_question_mark_with:
        x = _infer_question_mark(x, infer_question_mark_with)
    for part in x:
        assert type(part) is int, f"cannot infer value of {part}"
    return x


def _handle_ellipsis(x, before, after):
    ell = ["..."]
    try:
        i = before.index(ell)
        l = len(x.shape) - len(before) + 1
        ellipsis_value = x.shape[i : i + l]
        ellipsis_value = list(ellipsis_value)
        before = before[:i] + ellipsis_value + before[i + 1 :]
    except ValueError:
        pass
    try:
        i = after.index(ell)
        after = after[:i] + ellipsis_value + after[i + 1 :]
    except ValueError:
        pass
    except UnboundLocalError as e:
        raise ValueError("there cannot be an ellipsis in 'after' unless there is an ellipsis in 'before'") from e
    return before, after


def reshape_undo(inp, before, after, *, undo=None, known=None, **kwargs):
    """
    Usage:
    x_Bhwse, undo = reshape_undo(
        x_bthwe,
        'b, t, ..., stride*e',
        'b*t, ..., stride, e',
        stride=7
    )
    x_Bhwse = do_some_stuff(x_Bhwse)
    x_bthwe = undo(x_Bhwse)

    It's necessary to pass known values as keywords only
    when they can't be inferred from the shape.

    (E.g. in the above example we needed to pass
    stride but not b, t, or e, since those can be determined from
    inp.shape once stride is known.)
    """
    if known:
        known = {**kwargs, **known}
    else:
        known = kwargs
    assert type(before) is type(after), f"{type(before)} != {type(after)}"
    assert isinstance(inp, (th.Tensor, np.ndarray)), f"require tensor or ndarray but got {type(inp)}"
    assert isinstance(before, (str, list)), f"require str or list but got {type(before)}"
    if isinstance(before, str):
        before = _parse_reshape_str(before, "before")
        after = _parse_reshape_str(after, "after")
        before, after = _handle_ellipsis(inp, before, after)
    before_saved, after_saved = before, after
    before, known = _infer(known=known, desc=before, shape=inp.shape)
    before = _ground(before, known, product(inp.shape))
    after = _ground(after, known, product(inp.shape))
    known = {k: v for k, v in known.items() if not k.startswith(NO_BIND)}
    assert tuple(inp.shape) == tuple(before), f"expected shape {before} but got shape {inp.shape}"
    assert product(inp.shape) == product(
        after
    ), f"cannot reshape {inp.shape} to {after} because the number of elements does not match"
    return (
        inp.reshape(after),
        compose_undo(undo, lambda inp: reshape(inp, after_saved, before_saved, known=known)),
    )


def reshape(*args, **kwargs):
    """
    Please see the documentation for reshape_undo.
    """
    x, _ = reshape_undo(*args, **kwargs)
    return x


def product(xs, one=1):
    result = one
    for x in xs:
        result = result * x
    return result


def exact_div(a, b):
    assert a % b == 0, f"{a} is not divisible by {b}"
    return a // b
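Once the named dimensions in the reshape strings are grounded against the input shape, `reshape`/`reshape_undo` reduce to a plain tensor reshape plus its inverse. A minimal sketch of the docstring's usage pattern using numpy directly (the sizes `b, t, h, stride, e` are illustrative, not from the repo):

```python
import numpy as np

# Emulate reshape_undo(x, 'b, t, h, stride*e', 'b*t, h, stride, e', stride=7)
# after all named dims have been inferred from the input shape.
b, t, h, stride, e = 2, 3, 5, 7, 4
x = np.arange(b * t * h * stride * e).reshape(b, t, h, stride * e)

y = x.reshape(b * t, h, stride, e)       # the forward reshape
x_back = y.reshape(b, t, h, stride * e)  # what the returned undo() performs

assert y.shape == (6, 5, 7, 4)
assert x_back.shape == x.shape and (x_back == x).all()
```

Only `stride` needs to be supplied: `b`, `t`, `h`, and `e` are recoverable from `x.shape` once `stride` is fixed, which is exactly the point made in the docstring.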


================================================
FILE: metrics/IDM/lib/mlp.py
================================================
import torch as th
from torch import nn

from lib import misc
from lib import torch_util as tu


class MLP(nn.Module):
    def __init__(self, insize, nhidlayer, outsize, hidsize, hidactiv, dtype=th.float32):
        super().__init__()
        self.insize = insize
        self.nhidlayer = nhidlayer
        self.outsize = outsize
        in_sizes = [insize] + [hidsize] * nhidlayer
        out_sizes = [hidsize] * nhidlayer + [outsize]
        self.layers = nn.ModuleList(
            [tu.NormedLinear(insize, outsize, dtype=dtype) for (insize, outsize) in misc.safezip(in_sizes, out_sizes)]
        )
        self.hidactiv = hidactiv

    def forward(self, x):
        *hidlayers, finallayer = self.layers
        for layer in hidlayers:
            x = layer(x)
            x = self.hidactiv(x)
        x = finallayer(x)
        return x

    @property
    def output_shape(self):
        return (self.outsize,)
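The constructor above derives each layer's (in, out) sizes by zipping two shifted lists, so `nhidlayer` hidden layers yield `nhidlayer + 1` linear layers. A standalone sketch of that pairing (toy sizes, not from the repo):

```python
insize, nhidlayer, outsize, hidsize = 8, 2, 4, 16
in_sizes = [insize] + [hidsize] * nhidlayer    # [8, 16, 16]
out_sizes = [hidsize] * nhidlayer + [outsize]  # [16, 16, 4]
layer_shapes = list(zip(in_sizes, out_sizes))
assert layer_shapes == [(8, 16), (16, 16), (16, 4)]
```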


================================================
FILE: metrics/IDM/lib/normalize_ewma.py
================================================
import numpy as np
import torch
import torch.nn as nn


class NormalizeEwma(nn.Module):
    """Normalize a vector of observations - across the first norm_axes dimensions"""

    def __init__(self, input_shape, norm_axes=2, beta=0.99999, per_element_update=False, epsilon=1e-5):
        super().__init__()

        self.input_shape = input_shape
        self.norm_axes = norm_axes
        self.epsilon = epsilon
        self.beta = beta
        self.per_element_update = per_element_update

        self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)
        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)
        self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False)

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_mean_sq.zero_()
        self.debiasing_term.zero_()

    def running_mean_var(self):
        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
        return debiased_mean, debiased_var

    def forward(self, input_vector):
        # Make sure input is float32
        input_vector = input_vector.to(torch.float)

        if self.training:
            # Detach input before adding it to running means to avoid backpropping through it on
            # subsequent batches.
            detached_input = input_vector.detach()
            batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))
            batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))

            if self.per_element_update:
                batch_size = np.prod(detached_input.size()[: self.norm_axes])
                weight = self.beta ** batch_size
            else:
                weight = self.beta

            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))

        mean, var = self.running_mean_var()
        return (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]

    def denormalize(self, input_vector):
        """Transform normalized data back into original distribution"""
        mean, var = self.running_mean_var()
        return input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
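Because the running statistics above are initialized at zero, early EWMA estimates are biased toward zero; dividing by `debiasing_term` (the EWMA of the constant 1) corrects that, in the same spirit as Adam's bias correction. A scalar sketch with made-up values:

```python
beta = 0.9
running_mean, debias = 0.0, 0.0
for x in [5.0, 5.0, 5.0]:  # a constant stream with true mean 5
    running_mean = beta * running_mean + (1 - beta) * x
    debias = beta * debias + (1 - beta) * 1.0

# The raw EWMA underestimates the mean because it started at 0...
assert running_mean < 5.0
# ...but the debiased estimate recovers it for a constant stream.
assert abs(running_mean / debias - 5.0) < 1e-9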


================================================
FILE: metrics/IDM/lib/policy.py
================================================
from copy import deepcopy
from typing import Dict, Optional

import numpy as np
import torch as th
from gym3.types import DictType
from torch import nn
from torch.nn import functional as F

from lib.action_head import make_action_head
from lib.action_mapping import CameraHierarchicalMapping
from lib.impala_cnn import ImpalaCNN
from lib.normalize_ewma import NormalizeEwma
from lib.scaled_mse_head import ScaledMSEHead
from lib.tree_util import tree_map
from lib.util import FanInInitReLULayer, ResidualRecurrentBlocks
from lib.misc import transpose


class ImgPreprocessing(nn.Module):
    """Normalize incoming images.

    :param img_statistics: remote path to an npz file with a mean and std image. If specified,
        incoming images are normalized using these statistics.
    :param scale_img: If True and img_statistics is not specified, scale incoming images by 1/255.
    """

    def __init__(self, img_statistics: Optional[str] = None, scale_img: bool = True):
        super().__init__()
        self.img_mean = None
        if img_statistics is not None:
            img_statistics = dict(**np.load(img_statistics))
            self.img_mean = nn.Parameter(th.Tensor(img_statistics["mean"]), requires_grad=False)
            self.img_std = nn.Parameter(th.Tensor(img_statistics["std"]), requires_grad=False)
        else:
            self.ob_scale = 255.0 if scale_img else 1.0

    def forward(self, img):
        x = img.to(dtype=th.float32)
        if self.img_mean is not None:
            x = (x - self.img_mean) / self.img_std
        else:
            x = x / self.ob_scale
        return x


class ImgObsProcess(nn.Module):
    """ImpalaCNN followed by a linear layer.

    :param cnn_outsize: impala output dimension
    :param output_size: output size of the linear layer.
    :param dense_init_norm_kwargs: kwargs for linear FanInInitReLULayer
    :param init_norm_kwargs: kwargs for 2d and 3d conv FanInInitReLULayer
    """

    def __init__(
        self,
        cnn_outsize: int,
        output_size: int,
        dense_init_norm_kwargs: Dict = {},
        init_norm_kwargs: Dict = {},
        **kwargs,
    ):
        super().__init__()
        self.cnn = ImpalaCNN(
            outsize=cnn_outsize,
            init_norm_kwargs=init_norm_kwargs,
            dense_init_norm_kwargs=dense_init_norm_kwargs,
            **kwargs,
        )
        self.linear = FanInInitReLULayer(
            cnn_outsize,
            output_size,
            layer_type="linear",
            **dense_init_norm_kwargs,
        )

    def forward(self, img):
        return self.linear(self.cnn(img))


class MinecraftPolicy(nn.Module):
    """
    :param recurrence_type:
        None                - No recurrence, adds no extra layers
        lstm                - (Deprecated). Singular LSTM
        multi_layer_lstm    - Multi-layer LSTM. Uses n_recurrence_layers to determine number of consecutive LSTMs
            Does NOT support ragged batching
        multi_masked_lstm   - Multi-layer LSTM that supports ragged batching via the first vector. This model is slower
            Uses n_recurrence_layers to determine number of consecutive LSTMs
        transformer         - Dense transformer
    :param init_norm_kwargs: kwargs for all FanInInitReLULayers.
    """

    def __init__(
        self,
        recurrence_type="lstm",
        impala_width=1,
        impala_chans=(16, 32, 32),
        obs_processing_width=256,
        hidsize=512,
        single_output=False,  # True if we don't need separate outputs for action/value outputs
        img_shape=None,
        scale_input_img=True,
        only_img_input=False,
        init_norm_kwargs={},
        impala_kwargs={},
        # Unused argument assumed by forc.
        input_shape=None,  # pylint: disable=unused-argument
        active_reward_monitors=None,
        img_statistics=None,
        first_conv_norm=False,
        diff_mlp_embedding=False,
        attention_mask_style="clipped_causal",
        attention_heads=8,
        attention_memory_size=2048,
        use_pointwise_layer=True,
        pointwise_ratio=4,
        pointwise_use_activation=False,
        n_recurrence_layers=1,
        recurrence_is_residual=True,
        timesteps=None,
        use_pre_lstm_ln=True,  # Not needed for transformer
        **unused_kwargs,
    ):
        super().__init__()
        assert recurrence_type in [
            "multi_layer_lstm",
            "multi_layer_bilstm",
            "multi_masked_lstm",
            "transformer",
            "none",
        ]

        active_reward_monitors = active_reward_monitors or {}

        self.single_output = single_output

        chans = tuple(int(impala_width * c) for c in impala_chans)
        self.hidsize = hidsize

        # Dense init kwargs replaces batchnorm/groupnorm with layernorm
        self.init_norm_kwargs = init_norm_kwargs
        self.dense_init_norm_kwargs = deepcopy(init_norm_kwargs)
        if self.dense_init_norm_kwargs.get("group_norm_groups", None) is not None:
            self.dense_init_norm_kwargs.pop("group_norm_groups", None)
            self.dense_init_norm_kwargs["layer_norm"] = True
        if self.dense_init_norm_kwargs.get("batch_norm", False):
            self.dense_init_norm_kwargs.pop("batch_norm", False)
            self.dense_init_norm_kwargs["layer_norm"] = True

        # Setup inputs
        self.img_preprocess = ImgPreprocessing(img_statistics=img_statistics, scale_img=scale_input_img)
        self.img_process = ImgObsProcess(
            cnn_outsize=256,
            output_size=hidsize,
            inshape=img_shape,
            chans=chans,
            nblock=2,
            dense_init_norm_kwargs=self.dense_init_norm_kwargs,
            init_norm_kwargs=init_norm_kwargs,
            first_conv_norm=first_conv_norm,
            **impala_kwargs,
        )

        self.pre_lstm_ln = nn.LayerNorm(hidsize) if use_pre_lstm_ln else None
        self.diff_obs_process = None

        self.recurrence_type = recurrence_type

        self.recurrent_layer = None
        self.recurrent_layer = ResidualRecurrentBlocks(
            hidsize=hidsize,
            timesteps=timesteps,
            recurrence_type=recurrence_type,
            is_residual=recurrence_is_residual,
            use_pointwise_layer=use_pointwise_layer,
            pointwise_ratio=pointwise_ratio,
            pointwise_use_activation=pointwise_use_activation,
            attention_mask_style=attention_mask_style,
            attention_heads=attention_heads,
            attention_memory_size=attention_memory_size,
            n_block=n_recurrence_layers,
        )

        self.lastlayer = FanInInitReLULayer(hidsize, hidsize, layer_type="linear", **self.dense_init_norm_kwargs)
        self.final_ln = th.nn.LayerNorm(hidsize)

    def output_latent_size(self):
        return self.hidsize

    def forward(self, ob, state_in, context):
        first = context["first"]

        x = self.img_preprocess(ob["img"])
        x = self.img_process(x)

        if self.diff_obs_process:
            processed_obs = self.diff_obs_process(ob["diff_goal"])
            x = processed_obs + x

        if self.pre_lstm_ln is not None:
            x = self.pre_lstm_ln(x)

        if self.recurrent_layer is not None:
            x, state_out = self.recurrent_layer(x, first, state_in)
        else:
            state_out = state_in

        x = F.relu(x, inplace=False)

        x = self.lastlayer(x)
        x = self.final_ln(x)
        pi_latent = vf_latent = x
        if self.single_output:
            return pi_latent, state_out
        return (pi_latent, vf_latent), state_out

    def initial_state(self, batchsize):
        if self.recurrent_layer:
            return self.recurrent_layer.initial_state(batchsize)
        else:
            return None


class MinecraftAgentPolicy(nn.Module):
    def __init__(self, action_space, policy_kwargs, pi_head_kwargs):
        super().__init__()
        self.net = MinecraftPolicy(**policy_kwargs)

        self.action_space = action_space

        self.value_head = self.make_value_head(self.net.output_latent_size())
        self.pi_head = self.make_action_head(self.net.output_latent_size(), **pi_head_kwargs)

    def make_value_head(self, v_out_size: int, norm_type: str = "ewma", norm_kwargs: Optional[Dict] = None):
        return ScaledMSEHead(v_out_size, 1, norm_type=norm_type, norm_kwargs=norm_kwargs)

    def make_action_head(self, pi_out_size: int, **pi_head_opts):
        return make_action_head(self.action_space, pi_out_size, **pi_head_opts)

    def initial_state(self, batch_size: int):
        return self.net.initial_state(batch_size)

    def reset_parameters(self):
        # nn.Module does not define reset_parameters, so only reset the submodules.
        self.net.reset_parameters()
        self.pi_head.reset_parameters()
        self.value_head.reset_parameters()

    def forward(self, obs, first: th.Tensor, state_in):
        if isinstance(obs, dict):
            # We don't want to mutate the obs input.
            obs = obs.copy()

            # If special "mask" key is in obs,
            # It's for masking the logits.
            # We take it out (the network doesn't need it)
            mask = obs.pop("mask", None)
        else:
            mask = None

        (pi_h, v_h), state_out = self.net(obs, state_in, context={"first": first})

        pi_logits = self.pi_head(pi_h, mask=mask)
        vpred = self.value_head(v_h)

        return (pi_logits, vpred, None), state_out

    def get_logprob_of_action(self, pd, action):
        """
        Get logprob of taking action `action` given probability distribution
        (see `get_gradient_for_action` to get this distribution)
        """
        ac = tree_map(lambda x: x.unsqueeze(1), action)
        log_prob = self.pi_head.logprob(ac, pd)
        assert not th.isnan(log_prob).any()
        return log_prob[:, 0]

    def get_kl_of_action_dists(self, pd1, pd2):
        """
        Get the KL divergence between two action probability distributions
        """
        return self.pi_head.kl_divergence(pd1, pd2)

    def get_output_for_observation(self, obs, state_in, first):
        """
        Return gradient-enabled outputs for given observation.

        Use `get_logprob_of_action` to get log probability of action
        with the given probability distribution.

        Returns:
          - probability distribution given observation
          - value prediction for given observation
          - new state
        """
        # We need to add a fictitious time dimension everywhere
        obs = tree_map(lambda x: x.unsqueeze(1), obs)
        first = first.unsqueeze(1)

        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)

        return pd, self.value_head.denormalize(vpred)[:, 0], state_out

    @th.no_grad()
    def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False):
        # We need to add a fictitious time dimension everywhere
        obs = tree_map(lambda x: x.unsqueeze(1), obs)
        first = first.unsqueeze(1)

        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)

        if taken_action is None:
            ac = self.pi_head.sample(pd, deterministic=not stochastic)
        else:
            ac = tree_map(lambda x: x.unsqueeze(1), taken_action)
        log_prob = self.pi_head.logprob(ac, pd)
        assert not th.isnan(log_prob).any()

        # After unsqueezing, squeeze back to remove fictitious time dimension
        result = {"log_prob": log_prob[:, 0], "vpred": self.value_head.denormalize(vpred)[:, 0]}
        if return_pd:
            result["pd"] = tree_map(lambda x: x[:, 0], pd)
        ac = tree_map(lambda x: x[:, 0], ac)

        return ac, state_out, result

    @th.no_grad()
    def v(self, obs, first, state_in):
        """Predict value for a given mdp observation"""
        obs = tree_map(lambda x: x.unsqueeze(1), obs)
        first = first.unsqueeze(1)

        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)

        # After unsqueezing, squeeze back
        return self.value_head.denormalize(vpred)[:, 0]


class InverseActionNet(MinecraftPolicy):
    """
    Args:
        conv3d_params: params for the 3D CNN applied before the impala stack. They are just passed into th.nn.Conv3d.
    """

    def __init__(
        self,
        hidsize=512,
        conv3d_params=None,
        **MCPolicy_kwargs,
    ):
        super().__init__(
            hidsize=hidsize,
            # If we're using 3d conv, normalize the entire impala stack; otherwise
            # don't normalize the first impala layer, since the input is already normalized
            first_conv_norm=conv3d_params is not None,
            **MCPolicy_kwargs,
        )
        self.conv3d_layer = None
        if conv3d_params is not None:
            # 3D conv is the first layer, so don't normalize its input
            conv3d_init_params = deepcopy(self.init_norm_kwargs)
            conv3d_init_params["group_norm_groups"] = None
            conv3d_init_params["batch_norm"] = False
            self.conv3d_layer = FanInInitReLULayer(
                layer_type="conv3d",
                log_scope="3d_conv",
                **conv3d_params,
                **conv3d_init_params,
            )

    def forward(self, ob, state_in, context):
        first = context["first"]
        x = self.img_preprocess(ob["img"])

        # Conv3D Prior to Impala
        if self.conv3d_layer is not None:
            x = self._conv3d_forward(x)

        # Impala Stack
        x = self.img_process(x)

        if self.recurrent_layer is not None:
            x, state_out = self.recurrent_layer(x, first, state_in)

        x = F.relu(x, inplace=False)

        pi_latent = self.lastlayer(x)
        pi_latent = self.final_ln(pi_latent)
        return (pi_latent, None), state_out

    def _conv3d_forward(self, x):
        # Convert from (B, T, H, W, C) -> (B, C, T, H, W)
        x = transpose(x, "bthwc", "bcthw")
        new_x = []
        for mini_batch in th.split(x, 1):
            new_x.append(self.conv3d_layer(mini_batch))
        x = th.cat(new_x)
        # Convert back
        x = transpose(x, "bcthw", "bthwc")
        return x


class InverseActionPolicy(nn.Module):
    def __init__(
        self,
        action_space,
        pi_head_kwargs=None,
        idm_net_kwargs=None,
    ):
        super().__init__()
        self.action_space = action_space

        self.net = InverseActionNet(**idm_net_kwargs)

        pi_out_size = self.net.output_latent_size()

        pi_head_kwargs = {} if pi_head_kwargs is None else pi_head_kwargs

        self.pi_head = self.make_action_head(pi_out_size=pi_out_size, **pi_head_kwargs)

    def make_action_head(self, **kwargs):
        return make_action_head(self.action_space, **kwargs)

    def reset_parameters(self):
        # nn.Module does not define reset_parameters, so only reset the submodules.
        self.net.reset_parameters()
        self.pi_head.reset_parameters()

    def forward(self, obs, first: th.Tensor, state_in, **kwargs):
        if isinstance(obs, dict):
            # We don't want to mutate the obs input.
            obs = obs.copy()

            # If special "mask" key is in obs,
            # It's for masking the logits.
            # We take it out (the network doesn't need it)
            mask = obs.pop("mask", None)
        else:
            mask = None

        (pi_h, _), state_out = self.net(obs, state_in=state_in, context={"first": first}, **kwargs)
        pi_logits = self.pi_head(pi_h, mask=mask)
        return (pi_logits, None, None), state_out

    @th.no_grad()
    def predict(
        self,
        obs,
        deterministic: bool = True,
        **kwargs,
    ):
        (pd, _, _), state_out = self(obs=obs, **kwargs)

        ac = self.pi_head.sample(pd, deterministic=deterministic)
        log_prob = self.pi_head.logprob(ac, pd)

        assert not th.isnan(log_prob).any()

        result = {"log_prob": log_prob, "pd": pd}

        return ac, state_out, result

    def initial_state(self, batch_size: int):
        return self.net.initial_state(batch_size)


================================================
FILE: metrics/IDM/lib/scaled_mse_head.py
================================================
from typing import Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from lib.action_head import fan_in_linear
from lib.normalize_ewma import NormalizeEwma


class ScaledMSEHead(nn.Module):
    """
    Linear output layer that scales itself so that targets are always normalized to N(0, 1)
    """

    def __init__(
        self, input_size: int, output_size: int, norm_type: Optional[str] = "ewma", norm_kwargs: Optional[Dict] = None
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.norm_type = norm_type

        self.linear = nn.Linear(self.input_size, self.output_size)

        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.normalizer = NormalizeEwma(output_size, **norm_kwargs)

    def reset_parameters(self):
        init.orthogonal_(self.linear.weight)
        fan_in_linear(self.linear)
        self.normalizer.reset_parameters()

    def forward(self, input_data):
        return self.linear(input_data)

    def loss(self, prediction, target):
        """
        Calculate the MSE loss between output and a target.
        'Prediction' has to be normalized while target is denormalized.
        Loss is calculated in a 'normalized' space.
        """
        return F.mse_loss(prediction, self.normalizer(target), reduction="mean")

    def denormalize(self, input_data):
        """Convert input value from a normalized space into the original one"""
        return self.normalizer.denormalize(input_data)

    def normalize(self, input_data):
        return self.normalizer(input_data)


================================================
FILE: metrics/IDM/lib/torch_util.py
================================================
import functools
import itertools
import math
import os
import pickle
import re
import subprocess
import tempfile
from contextlib import contextmanager
from hashlib import md5, sha1

import numpy as np
import torch as th
import torch.distributed as dist
import torch.distributions as dis
import torch.nn.functional as F
from torch import nn

import lib.tree_util as tree_util
from lib import misc


def contextmanager_to_decorator(cm):
    def decorator(fn):
        @functools.wraps(fn)
        def newfn(*args, **kwargs):
            with cm():
                return fn(*args, **kwargs)

        return newfn

    return decorator


def have_cuda():
    return th.cuda.is_available()


def default_device_type():
    return "cuda" if have_cuda() else "cpu"


no_grad = contextmanager_to_decorator(th.no_grad)
DEFAULT_DEVICE = th.device(type=default_device_type())


def set_default_torch_device(device):
    global DEFAULT_DEVICE
    DEFAULT_DEVICE = th.device(device)


def dev():
    return DEFAULT_DEVICE


def zeros(*args, **kwargs):
    return th.zeros(*args, **kwargs, device=dev())


def ones(*args, **kwargs):
    return th.ones(*args, **kwargs, device=dev())


def arange(*args, **kwargs):
    return th.arange(*args, **kwargs, device=dev())


def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):
    """
    nn.Linear but with normalized fan-in init
    """
    dtype = parse_dtype(dtype)
    if dtype == th.float32:
        out = nn.Linear(*args, **kwargs)
    elif dtype == th.float16:
        out = LinearF16(*args, **kwargs)
    else:
        raise ValueError(dtype)
    out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)
    if kwargs.get("bias", True):
        out.bias.data *= 0
    return out
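The init above rescales each weight row to a fixed L2 norm ("normalized fan-in"). A numpy sketch of just that rescaling step (shapes and the seed are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 8))  # (out_features, in_features), as in nn.Linear
scale = 1.0

# Same operation as `out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)`:
W = W * (scale / np.linalg.norm(W, axis=1, keepdims=True))

# Every output unit now starts with the same weight-vector norm.
assert np.allclose(np.linalg.norm(W, axis=1), scale)
```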


class LinearF16(nn.Linear):
    def forward(self, x):
        return F.linear(x, self.weight.half(), self.bias.half() if self.bias is not None else None)


class LayerNormF16(nn.LayerNorm):
    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, self.weight.half(), self.bias.half(), self.eps)


def LayerNorm(*args, dtype=th.float32, **kwargs):
    dtype = parse_dtype(dtype)
    if dtype == th.float32:
        out = nn.LayerNorm(*args, **kwargs)
    elif dtype == th.float16:
        out = LayerNormF16(*args, **kwargs)
    else:
        raise ValueError(dtype)
    out.weight.no_scale = True
    return out


def flatten_image(x):
    """
    Flattens last three dims
    """
    *batch_shape, h, w, c = x.shape
    return x.reshape((*batch_shape, h * w * c))


def sequential(layers, x, *args, diag_name=None, use_checkpoint=False):
    for layer in layers:
        x = layer(x, *args)
    return x


@no_grad
def load_average_with_metadata(paths, overrides):
    n_models = len(paths)
    model, metadata = load_with_metadata(paths[0], overrides=overrides)
    for p in model.parameters():
        p.mul_(1 / n_models)
    for p in paths[1:]:
        new_model, _ = load_with_metadata(p, overrides=overrides)
        for (n1, p1), (n2, p2) in misc.safezip(model.named_parameters(), new_model.named_parameters()):
            assert n1 == n2, f"names {n1} and {n2} don't match"
            p1.add_(p2.mul_(1 / n_models))
    return model, metadata


def save_kwargs(fn):
    """
    This decorator passes through the user-provided kwargs and adds one more, called
    save_kwargs, mapping to {"create_fn" : name_of_decorated_fn, "kwargs" : other_kwargs}

    You put on this decorator on a function that creates a pytorch module. This will
    save the kwargs and the function that was used to create the module.
    This lets us restore the model state later.
    """

    @functools.wraps(fn)
    def wrapper(**kwargs):
        if "save_kwargs" in kwargs:
            return fn(**kwargs)
        else:
            sk = {**kwargs, "create_fn": f"{fn.__module__}:{fn.__name__}"}
            return fn(save_kwargs=sk, **kwargs)

    return wrapper
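A runnable miniature of the decorator's behavior: the factory receives, alongside its normal kwargs, a `save_kwargs` dict recording those kwargs plus the factory's qualified name, which is enough to re-create the module later. (`make_model` is a hypothetical factory used only for this sketch.)

```python
import functools

def save_kwargs_sketch(fn):  # mirrors save_kwargs above
    @functools.wraps(fn)
    def wrapper(**kwargs):
        if "save_kwargs" in kwargs:
            return fn(**kwargs)
        sk = {**kwargs, "create_fn": f"{fn.__module__}:{fn.__name__}"}
        return fn(save_kwargs=sk, **kwargs)
    return wrapper

@save_kwargs_sketch
def make_model(hidsize=4, save_kwargs=None):
    return {"hidsize": hidsize, "save_kwargs": save_kwargs}

m = make_model(hidsize=8)
assert m["save_kwargs"]["hidsize"] == 8
assert m["save_kwargs"]["create_fn"].endswith(":make_model")
```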


def parse_dtype(x):
    if isinstance(x, th.dtype):
        return x
    elif isinstance(x, str):
        if x == "float32" or x == "float":
            return th.float32
        elif x == "float64" or x == "double":
            return th.float64
        elif x == "float16" or x == "half":
            return th.float16
        elif x == "uint8":
            return th.uint8
        elif x == "int8":
            return th.int8
        elif x == "int16" or x == "short":
            return th.int16
        elif x == "int32" or x == "int":
            return th.int32
        elif x == "int64" or x == "long":
            return th.int64
        elif x == "bool":
            return th.bool
        else:
            raise ValueError(f"cannot parse {x} as a dtype")
    else:
        raise TypeError(f"cannot parse {type(x)} as dtype")


def index(x, i):
    """
    Batched, broadcasting index of x along dimension i.ndim.

    For example, if x has shape (1, 2, 3, 4, 5) and i has shape (1, 1, 3)
    then the result has shape (1, 2, 3, 5) and each value in i must be between 0 and 3.
    """
    assert x.ndim >= i.ndim + 1
    gather_dim = i.ndim
    while i.ndim < x.ndim:
        i = i.unsqueeze(-1)
    expand_shape = list(x.shape)
    expand_shape[gather_dim] = 1
    i = i.expand(*expand_shape)
    xi = th.gather(x, gather_dim, i)
    assert xi.shape[gather_dim] == 1
    return xi.squeeze(gather_dim)
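The gather in `index` can be reproduced with numpy's `take_along_axis`; the sketch below checks the shape example from the docstring with synthetic arrays:

```python
import numpy as np

def index_sketch(x, i):
    gather_dim = i.ndim
    while i.ndim < x.ndim:  # append trailing singleton dims, like unsqueeze(-1)
        i = i[..., None]
    # Broadcast to x's shape with the gathered dim collapsed to 1.
    target = x.shape[:gather_dim] + (1,) + x.shape[gather_dim + 1:]
    i = np.broadcast_to(i, target)
    return np.take_along_axis(x, i, axis=gather_dim).squeeze(gather_dim)

x = np.arange(120).reshape(1, 2, 3, 4, 5)
i = np.zeros((1, 1, 3), dtype=np.int64)  # values must lie in [0, 4)
out = index_sketch(x, i)
assert out.shape == (1, 2, 3, 5)
assert (out == x[:, :, :, 0, :]).all()
```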


================================================
FILE: metrics/IDM/lib/tree_util.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copied this from jax, made it self-contained
# Currently just used for improved_checkpoint

import collections
import functools
import itertools as it
from collections.abc import Collection
from typing import Dict, List, Optional


def unzip2(xys):
    xs = []
    ys = []
    for x, y in xys:
        xs.append(x)
        ys.append(y)
    return tuple(xs), tuple(ys)


def partial(fun, *args, **kwargs):
    wrapped = functools.partial(fun, *args, **kwargs)
    functools.update_wrapper(wrapped, fun)
    wrapped._bound_args = args  # pylint: disable=protected-access
    return wrapped


def safe_zip(*args: Collection) -> List[tuple]:
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, "length mismatch: {}".format(list(map(len, args)))
    return list(zip(*args))


def safe_map(f, *args):
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, "length mismatch: {}".format(list(map(len, args)))
    return list(map(f, *args))


def tree_map(f, tree, treat_as_leaves: Optional[List] = None):
    """Map a function over a pytree to produce a new pytree.

    Args:
      f: function to be applied at each leaf.
      tree: a pytree to be mapped over.

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x)` where `x` is the value at the corresponding leaf in
      `tree`.
    """
    if treat_as_leaves is None:
        treat_as_leaves = []
    node_type = node_types.get(type(tree))
    if node_type and type(tree) not in treat_as_leaves:
        children, node_spec = node_type.to_iterable(tree)
        new_children = [tree_map(f, child, treat_as_leaves) for child in children]
        return node_type.from_iterable(node_spec, new_children)
    else:
        return f(tree)
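`tree_map` dispatches through a `node_types` registry defined elsewhere in tree_util; the recursion itself can be sketched standalone for dicts, lists, and tuples (a simplification of the registry-based version):

```python
def tree_map_sketch(f, tree):
    if isinstance(tree, dict):
        return {k: tree_map_sketch(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map_sketch(f, v) for v in tree)
    return f(tree)  # anything unregistered is a leaf

out = tree_map_sketch(lambda x: x * 2, {"a": [1, 2], "b": (3,)})
assert out == {"a": [2, 4], "b": (6,)}
```

The structure (keys, nesting, container types) is preserved; only the leaves pass through `f`, which is exactly how the policy code applies `unsqueeze`/indexing uniformly across observation dicts.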


def tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None):
    """Map a multi-input function over pytree args to produce a new pytree.

    Args:
      f: function that takes `1 + len(rest)` arguments, to be applied at the
        corresponding leaves of the pytrees.
      tree: a pytree to be mapped over, with each leaf providing the first
        positional argument to `f`.
      *rest: a tuple of pytrees, each with the same structure as `tree`.

    Returns:
      A new pytree with the same structure as `tree` but with the value at each
      leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
      in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
    """

    if treat_as_leaves is None:
        treat_as_leaves = []
    node_type = node_types.get(type(tree))
    if node_type and type(tree) not in treat_as_leaves:
        children, node_spec = node_type.to_iterable(tree)
        all_children = [children]
        for other_tree in rest:
            other_children, other_node_data = node_type.to_iterable(other_tree)
            if other_node_data != node_spec:
                raise TypeError("Mismatch: {} != {}".format(other_node_data, node_spec))
            all_children.append(other_children)

        new_children = [tree_multimap(f, *xs, treat_as_leaves=treat_as_leaves) for xs in zip(*all_children)]
        return node_type.from_iterable(node_spec, new_children)
    else:
        return f(tree, *rest)
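The same idea in a standalone sketch: walk several same-shaped trees in lockstep and combine corresponding leaves with `f`. `mini_tree_multimap` is a hypothetical name, not part of this module:

```python
# Illustrative standalone sketch of tree_multimap for dict/list/tuple trees.
def mini_tree_multimap(f, tree, *rest):
    if isinstance(tree, dict):
        return {
            k: mini_tree_multimap(f, tree[k], *(r[k] for r in rest))
            for k in sorted(tree)
        }
    if isinstance(tree, (list, tuple)):
        return type(tree)(mini_tree_multimap(f, *xs) for xs in zip(tree, *rest))
    return f(tree, *rest)  # leaves from all trees are combined by f

print(mini_tree_multimap(lambda a, b: a + b,
                         {"x": 1, "y": (2, 3)},
                         {"x": 10, "y": (20, 30)}))
# -> {'x': 11, 'y': (22, 33)}
```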


def prefix_multimap(f, treedef, tree, *rest):
    """Like tree_multimap but only maps down through a tree prefix."""
    if isinstance(treedef, PyLeaf):
        return f(tree, *rest)
    else:
        node_type = node_types.get(type(tree))
        if node_type != treedef.node_type:
            raise TypeError("Mismatch: {} != {}".format(treedef.node_type, node_type))
        children, node_data = node_type.to_iterable(tree)
        if node_data != treedef.node_data:
            raise TypeError("Mismatch: {} != {}".format(treedef.node_data, node_data))
        all_children = [children]
        for other_tree in rest:
            other_children, other_node_data = node_type.to_iterable(other_tree)
            if other_node_data != node_data:
                raise TypeError("Mismatch: {} != {}".format(other_node_data, node_data))
            all_children.append(other_children)
        all_children = zip(*all_children)

        new_children = [prefix_multimap(f, td, *xs) for td, xs in zip(treedef.children, all_children)]
        return node_type.from_iterable(node_data, new_children)


def walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = None):
    node_type = node_types.get(type(tree))
    if treat_as_leaves is None:
        treat_as_leaves = []

    if node_type and type(tree) not in treat_as_leaves:
        children, node_spec = node_type.to_iterable(tree)
        proc_children, child_specs = unzip2([walk_pytree(f_node, f_leaf, child, treat_as_leaves) for child in children])
        tree_def = PyTreeDef(node_type, node_spec, child_specs)
        return f_node(proc_children), tree_def
    else:
        return f_leaf(tree), PyLeaf()


def build_tree(treedef, xs):
    if isinstance(treedef, PyLeaf):
        return xs
    else:
        # We use 'iter' for clearer error messages
        children = safe_map(build_tree, iter(treedef.children), iter(xs))
        return treedef.node_type.from_iterable(treedef.node_data, children)


def _tree_unflatten(xs, treedef):
    if isinstance(treedef, PyLeaf):
        return next(xs)
    else:
        children = safe_map(partial(_tree_unflatten, xs), treedef.children)
        return treedef.node_type.from_iterable(treedef.node_data, children)
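The flatten/unflatten round trip that `walk_pytree` and `_tree_unflatten` implement can be sketched standalone. The `mini_*` names and the plain-tuple treedef encoding below are hypothetical stand-ins for `PyTreeDef`/`PyLeaf`, chosen only to keep the example self-contained:

```python
# Illustrative round trip: flatten a dict/list/tuple pytree into a leaf list
# plus a structure descriptor, then rebuild the original tree from them.
def mini_flatten(tree):
    if isinstance(tree, dict):
        keys = sorted(tree)
        parts = [mini_flatten(tree[k]) for k in keys]
        return [l for ls, _ in parts for l in ls], ("dict", keys, [td for _, td in parts])
    if isinstance(tree, (list, tuple)):
        parts = [mini_flatten(x) for x in tree]
        return [l for ls, _ in parts for l in ls], (type(tree).__name__, None, [td for _, td in parts])
    return [tree], "*"  # "*" plays the role of PyLeaf

def mini_unflatten(leaf_iter, treedef):
    if treedef == "*":
        return next(leaf_iter)  # consume the next leaf, as _tree_unflatten does
    kind, data, children = treedef
    vals = [mini_unflatten(leaf_iter, td) for td in children]
    if kind == "dict":
        return dict(zip(data, vals))
    return tuple(vals) if kind == "tuple" else list(vals)

tree = {"a": (1, 2), "b": [3]}
leaves, treedef = mini_flatten(tree)
print(leaves)                                         # -> [1, 2, 3]
print(mini_unflatten(iter(leaves), treedef) == tree)  # -> True
```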


def _num_leaves(treedef):
    return 1 if isinstance(treedef, PyLeaf) else sum(safe_map(_num_leaves, treedef.children))


def _nested_treedef(inner, outer):
    # just used in tree_transpose error checking
    if isinstance(outer, PyLeaf):
        return inner
    else:
        children = safe_map(partial(_nested_treedef, inner), outer.children)
        return PyTreeDef(outer.node_type, outer.node_data, tuple(children))


class PyTreeDef(object):
    def __init__(self, node_type, node_data, children):
        self.node_type = node_type
        self.node_data = node_data
        self.children = children

    def __repr__(self):
        if self.node_data is None:
            data_repr = ""
        else:
            data_repr = "[{}]".format(self.node_data)

        return "PyTree({}{}, [{}])".format(self.node_type.name, data_repr, ",".join(safe_map(repr, self.children)))

    def __hash__(self):
        return hash((self.node_type, self.node_data, tuple(self.children)))

    def __eq__(self, other):
        if isinstance(other, PyLeaf):
            return False
        else:
            return self.node_type == other.node_type and self.node_data == other.node_data and self.children == other.children

    def __ne__(self, other):
        return not self == other


class PyLeaf(object):
    def __repr__(self):
        return "*"

    def __eq__(self, other):
        return isinstance(other, PyLeaf)


class NodeType(object):
    def __init__(self, name, to_iterable, from_iterable):
        self.name = name
        self.to_iterable = to_iterable
        self.from_iterable = from_iterable


node_types: Dict[type, NodeType] = {}


def register_pytree_node(py_type, to_iterable, from_iterable):
    assert py_type not in node_types
    node_types[py_type] = NodeType(str(py_type), to_iterable, from_iterable)


def tuple_to_iterable(xs):
    return xs, None


def tuple_from_iterable(_keys, xs):
    return tuple(xs)


def list_to_iterable(xs):
    return tuple(xs), None


def list_from_iterable(_keys, xs):
    return list(xs)


def dict_to_iterable(xs):
    keys = tuple(sorted(xs.keys()))
    return tuple(map(xs.get, keys)), keys


def dict_from_iterable(keys, xs):
    return dict(safe_zip(keys, xs))


def ordered_dict_from_iterable(keys, xs):
    return collections.OrderedDict(safe_zip(keys, xs))


def default_dict_to_iterable(xs):
    return (tuple(xs.values()), (xs.default_factory, tuple(xs.keys())))


def default_dict_from_iterable(keys, xs):
    return collections.defaultdict(keys[0], safe_zip(keys[1], xs))


def none_to_iterable(_xs):
    return (), None


def none_from_iterable(_keys, _xs):
    return None


register_pytree_node(tuple, tuple_to_iterable, tuple_from_iterable)
register_pytree_node(list, list_to_iterable, list_from_iterable)
register_pytree_node(dict, dict_to_iterable, dict_from_iterable)
register_pytree_node(collections.OrderedDict, dict_to_iterable, ordered_dict_from_iterable)
register_pytree_node(collections.defaultdict, default_dict_to_iterable, default_dict_from_iterable)
register_pytree_node(type(None), none_to_iterable, none_from_iterable)
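Extending the registry to a user-defined container follows the same to_iterable/from_iterable pattern as the registrations above. The sketch below uses a local `mini_registry` and `mini_map` so it is self-contained; in this module you would instead call `register_pytree_node` and reuse `tree_map`:

```python
# Illustrative sketch of registering a custom container type as a pytree node.
import collections

Point = collections.namedtuple("Point", ["x", "y"])

mini_registry = {}
mini_registry[Point] = (
    lambda p: ((p.x, p.y), None),   # to_iterable: (children, aux node data)
    lambda _data, xs: Point(*xs),   # from_iterable: rebuild from new children
)

def mini_map(f, tree):
    entry = mini_registry.get(type(tree))
    if entry is None:
        return f(tree)  # unregistered types are leaves
    to_iterable, from_iterable = entry
    children, data = to_iterable(tree)
    return from_iterable(data, [mini_map(f, c) for c in children])

print(mini_map(lambda v: v + 1, Point(1, 2)))  # -> Point(x=2, y=3)
```

Note the lookup is by exact type (`type(tree)`), so a `Point` is only recursed into once registered, even though it subclasses `tuple`.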


================================================
FILE: metrics/IDM/lib/util.py
================================================
from typing import Dict, Optional

import torch as th
from torch import nn
from torch.nn import functional as F

import lib.torch_util as tu
from lib.masked_attention import MaskedAttention
from lib.minecraft_util import store_args
from lib.tree_util import tree_map


def get_module_log_keys_recursive(m: nn.Module):
    """Recursively get all keys that a module and its children want to log."""
    keys = []
    if hasattr(m, "get_log_keys"):
        keys += m.get_log_keys()
    for c in m.children():
        keys += get_module_log_keys_recursive(c)
    return keys


class FanInInitReLULayer(nn.Module):
    """Implements a slightly modified init that correctly produces std 1 outputs given ReLU activation
    :param inchan: number of input channels
    :param outchan: number of output channels
    :param layer_args: positional layer args
    :param layer_type: options are "linear" (dense layer), "conv" (2D Convolution), "conv3d" (3D convolution)
    :param init_scale: multiplier on initial weights
    :param batch_norm: use batch norm after the layer (for 2D data)
    :param group_norm_groups: if not None, use group norm with this many groups after the layer. Group norm
        with a single group is equivalent to layer norm for 2D data.
    :param layer_norm: use layernorm after the layer (for 1D data)
    :param layer_kwargs: keyword arguments for the layer
    """

    @store_args
    def __init__(
        self,
        inchan: int,
        outchan: int,
        *layer_args,
        layer_type: str = "conv",
        init_scale: int = 1,
        batch_norm: bool = False,
        batch_norm_kwargs: Dict = {},
        group_norm_groups: Optional[int] = None,
        layer_norm: bool = False,
        use_activation=True,
        log_scope: Optional[str] = None,
        **layer_kwargs,
    ):
        super().__init__()

        # Normalization
        self.norm = None
        if batch_norm:
            self.norm = nn.BatchNorm2d(inchan, **batch_norm_kwargs)
        elif group_norm_groups is not None:
            self.norm = nn.GroupNorm(group_norm_groups, inchan)
        elif layer_norm:
            self.norm = nn.LayerNorm(inchan)

        layer = dict(conv=nn.Conv2d, conv3d=nn.Conv3d, linear=nn.Linear)[layer_type]
        self.layer = layer(inchan, outchan, bias=self.norm is None, *layer_args, **layer_kwargs)

        # Init Weights (Fan-In)
        self.layer.weight.data *= init_scale / self.layer.weight.norm(
            dim=tuple(range(1, self.layer.weight.data.ndim)), p=2, keepdim=True
        )
        # Init Bias
        if self.layer.bias is not None:
            self.layer.bias.data *= 0

    def forward(self, x):
        """Norm after the activation. Experimented with this for both IAM and BC and it was slightly better."""
        if self.norm is not None:
            x = self.norm(x)
        x = self.layer(x)
        if self.use_activation:
            x = F.relu(x, inplace=True)
        return x

    def get_log_keys(self):
        return [
            f"activation_mean/{self.log_scope}",
            f"activation_std/{self.log_scope}",
        ]
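The fan-in rescaling in `__init__` above can be checked numerically: every output unit's weight row is scaled to L2 norm `init_scale`, so unit-variance inputs give pre-activations on a controlled scale. A pure-stdlib stand-in for the torch version (weight shape `(outchan, inchan)` as in `nn.Linear`; all names here are illustrative):

```python
# Standalone check of the fan-in init: rescale each weight row by
# init_scale / ||row||_2, mirroring
#   weight.data *= init_scale / weight.norm(dim=..., p=2, keepdim=True)
import math
import random

random.seed(0)
inchan, outchan, init_scale = 8, 4, 1.0
w = [[random.gauss(0.0, 1.0) for _ in range(inchan)] for _ in range(outchan)]
w = [
    [v * init_scale / math.sqrt(sum(x * x for x in row)) for v in row]
    for row in w
]
row_norms = [math.sqrt(sum(x * x for x in row)) for row in w]
print(all(abs(n - init_scale) < 1e-9 for n in row_norms))  # -> True
```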


class ResidualRecurrentBlocks(nn.Module):
    @store_args
    def __init__(
        self,
        n_block=2,
        recurrence_type="multi_layer_lstm",
        is_residual=True,
        **block_kwargs,
    ):
        super().__init__()
        init_scale = n_block ** -0.5 if is_residual else 1
        self.blocks = nn.ModuleList(
            [
                ResidualRecurrentBlock(
                    **block_kwargs,
                    recurrence_type=recurrence_type,
 
SYMBOL INDEX (378 symbols across 25 files)

FILE: diagonal_decoding.py
  function sample_top_k (line 5) | def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] ...
  function multinomial_sample_one_no_sync (line 18) | def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
  function logits_to_probs (line 26) | def logits_to_probs(
  function sample_top_p (line 40) | def sample_top_p(logits, temperature, top_p, vocab_size=8192):
  function sample_n_top_p (line 51) | def sample_n_top_p(logits, temperature, top_p, vocab_size=8192):
  function sample_n_top_k (line 62) | def sample_n_top_k(logits, temperature: float = 1.0, top_k: Optional[int...
  function logits_to_n_probs (line 73) | def logits_to_n_probs(
  function decode_one_token (line 87) | def decode_one_token(
  function decode_some_token (line 104) | def decode_some_token(
  function decode_n_tokens (line 121) | def decode_n_tokens(
  function decode_n_tokens_for_gradio (line 169) | def decode_n_tokens_for_gradio(
  function prefill (line 207) | def prefill(
  function img_diagd_prepare_inputs (line 226) | def img_diagd_prepare_inputs(
  function img_diagd_decode_n_tokens (line 257) | def img_diagd_decode_n_tokens(
  function img_diagd_prepare_inputs_for_gradio (line 353) | def img_diagd_prepare_inputs_for_gradio(
  function img_diagd_decode_n_token_for_gradio (line 376) | def img_diagd_decode_n_token_for_gradio(
  function vid_diagd_prepare_inputs (line 477) | def vid_diagd_prepare_inputs(
  function video_diagd_decode_n_tokens (line 518) | def video_diagd_decode_n_tokens(

FILE: inference.py
  function token2video (line 30) | def token2video(code_list, tokenizer, save_path, fps, device = 'cuda'):
  function get_args (line 50) | def get_args():
  function lvm_generate (line 69) | def lvm_generate(args, model, output_dir, demo_video):

FILE: lvm.py
  function rotate_half (line 27) | def rotate_half(x):
  function apply_rotary_pos_emb (line 33) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
  class LlamaLVM (line 54) | class LlamaLVM(torch.nn.Module):
    method __init__ (line 55) | def __init__(
  class LlamaRMSNorm (line 67) | class LlamaRMSNorm(nn.Module):
    method __init__ (line 68) | def __init__(self, hidden_size, eps=1e-6):
    method forward (line 76) | def forward(self, hidden_states):
    method extra_repr (line 83) | def extra_repr(self):
  class LlamaRotaryEmbedding (line 86) | class LlamaRotaryEmbedding(nn.Module):
    method __init__ (line 87) | def __init__(
    method _set_cos_sin_cache (line 105) | def _set_cos_sin_cache(self, device, dtype):
    method forward (line 127) | def forward(self, x, seq_len=None):
  class LlamaMLP (line 146) | class LlamaMLP(nn.Module):
    method __init__ (line 147) | def __init__(self, config):
    method forward (line 157) | def forward(self, x):
  class LlamaAttention (line 161) | class LlamaAttention(nn.Module):
    method __init__ (line 162) | def __init__(self, config: LlamaConfig):
    method init_kv_cache (line 188) | def init_kv_cache(self, dtype=torch.float16):
    method forward (line 193) | def forward(
  class LlamaDecoderLayer (line 248) | class LlamaDecoderLayer(nn.Module):
    method __init__ (line 249) | def __init__(self, config: LlamaConfig):
    method forward (line 257) | def forward(
  class LlamaModel (line 299) | class LlamaModel(PreTrainedModel):
    method __init__ (line 307) | def __init__(self, config: LlamaConfig):
    method _create_attention_mask (line 323) | def _create_attention_mask(self, input_pos: Optional[torch.Tensor]):
    method forward (line 336) | def forward(
  class LlamaForCausalLM (line 366) | class LlamaForCausalLM(PreTrainedModel):
    method __init__ (line 368) | def __init__(self, config):
    method forward (line 375) | def forward(
    method refresh_kvcache (line 387) | def refresh_kvcache(self):
    method naive_generate (line 391) | def naive_generate(self, input_ids, max_new_tokens, temperature=1.0, a...
    method prefill_for_gradio (line 422) | def prefill_for_gradio(self, input_ids, temperature=1.0):
    method decode_img_token_for_gradio (line 435) | def decode_img_token_for_gradio(self, input_action, position_id, max_n...
    method diagd_img_token_for_gradio (line 452) | def diagd_img_token_for_gradio(self, input_action, position_id, max_ne...
    method img_diagd_generate (line 468) | def img_diagd_generate(self, input_ids, max_new_tokens, temperature=1....
    method vid_diagd_generate (line 500) | def vid_diagd_generate(self, input_ids, max_new_tokens,windowsize=2, t...

FILE: mcdataset.py
  class Buttons (line 102) | class Buttons:
  class QuantizationScheme (line 143) | class QuantizationScheme:
  class CameraQuantizer (line 150) | class CameraQuantizer:
    method discretize (line 183) | def discretize(self, xy):
    method undiscretize (line 195) | def undiscretize(self, xy):
  class MCDataset (line 206) | class MCDataset(torch.utils.data.Dataset):
    method __init__ (line 210) | def __init__(self,
    method json_action_to_env_action (line 225) | def json_action_to_env_action(self, json_action):
    method make_action_vocab (line 297) | def make_action_vocab(self,
    method _handle_conflict_action_index (line 325) | def _handle_conflict_action_index(self,
    method get_action_index_from_actiondict (line 342) | def get_action_index_from_actiondict(self,
    method read_jsonl (line 391) | def read_jsonl(self, jsonl_path: str):

FILE: metrics/IDM/inverse_dynamics_model.py
  function resize_image (line 23) | def resize_image(img, target_resolution):
  class IDMAgent (line 36) | class IDMAgent:
    method __init__ (line 42) | def __init__(self, idm_net_kwargs, pi_head_kwargs, device=None):
    method load_weights (line 60) | def load_weights(self, path):
    method reset (line 65) | def reset(self):
    method _video_obs_to_agent (line 69) | def _video_obs_to_agent(self, video_frames):
    method _agent_action_to_env (line 76) | def _agent_action_to_env(self, agent_action):
    method predict_actions (line 89) | def predict_actions(self, video_frames):
  function json_action_to_env_action (line 183) | def json_action_to_env_action(json_action):
  function load_action_jsonl (line 233) | def load_action_jsonl(json_path):
  function evaluate_IDM_quality (line 242) | def evaluate_IDM_quality(model, weights,jsonl_folder, video_folder, infe...
  function construct_classification_labels (line 312) | def construct_classification_labels(idm_actions:dict[str, list[int]],act...
  function define_exclusive_classification_task (line 339) | def define_exclusive_classification_task(predicted_actions:dict,recorded...
  function classification_metric (line 371) | def classification_metric(pred_labels, rec_labels, class_num):
  function aggregate_actions (line 393) | def aggregate_actions(actions:list) -> dict:
  function idm_prediction (line 404) | def idm_prediction(agent, video_path,json_path, infer_demo_num, n_frames):
  function camera_loss (line 428) | def camera_loss(predicted_actions,recorded_actions):

FILE: metrics/IDM/lib/action_head.py
  function fan_in_linear (line 14) | def fan_in_linear(module: nn.Module, scale=1.0, bias=True):
  class ActionHead (line 22) | class ActionHead(nn.Module):
    method forward (line 25) | def forward(self, input_data: torch.Tensor) -> Any:
    method logprob (line 32) | def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor...
    method entropy (line 36) | def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
    method sample (line 40) | def sample(self, pd_params: torch.Tensor, deterministic: bool = False)...
    method kl_divergence (line 49) | def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor...
  class DiagGaussianActionHead (line 54) | class DiagGaussianActionHead(ActionHead):
    method __init__ (line 63) | def __init__(self, input_dim: int, num_dimensions: int):
    method reset_parameters (line 72) | def reset_parameters(self):
    method forward (line 76) | def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor:
    method logprob (line 86) | def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor...
    method entropy (line 97) | def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
    method sample (line 105) | def sample(self, pd_params: torch.Tensor, deterministic: bool = False)...
    method kl_divergence (line 114) | def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor...
  class CategoricalActionHead (line 136) | class CategoricalActionHead(ActionHead):
    method __init__ (line 139) | def __init__(
    method reset_parameters (line 157) | def reset_parameters(self):
    method forward (line 163) | def forward(self, input_data: torch.Tensor, mask=None) -> Any:
    method logprob (line 176) | def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torc...
    method entropy (line 186) | def entropy(self, logits: torch.Tensor) -> torch.Tensor:
    method sample (line 195) | def sample(self, logits: torch.Tensor, deterministic: bool = False) ->...
    method kl_divergence (line 209) | def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor...
  class DictActionHead (line 223) | class DictActionHead(nn.ModuleDict):
    method reset_parameters (line 226) | def reset_parameters(self):
    method forward (line 230) | def forward(self, input_data: torch.Tensor, **kwargs) -> Any:
    method logprob (line 250) | def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torc...
    method sample (line 253) | def sample(self, logits: torch.Tensor, deterministic: bool = False) ->...
    method entropy (line 256) | def entropy(self, logits: torch.Tensor) -> torch.Tensor:
    method kl_divergence (line 259) | def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor...
  function make_action_head (line 263) | def make_action_head(ac_space: ValType, pi_out_size: int, temperature: f...

FILE: metrics/IDM/lib/action_mapping.py
  class ActionMapping (line 12) | class ActionMapping(abc.ABC):
    method __init__ (line 30) | def __init__(self, n_camera_bins: int = 11):
    method from_factored (line 42) | def from_factored(self, ac: Dict) -> Dict:
    method to_factored (line 50) | def to_factored(self, ac: Dict) -> Dict:
    method get_action_space_update (line 58) | def get_action_space_update(self):
    method get_zero_action (line 63) | def get_zero_action(self):
    method factored_buttons_to_groups (line 67) | def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_gr...
  class IDMActionMapping (line 102) | class IDMActionMapping(ActionMapping):
    method from_factored (line 104) | def from_factored(self, ac: Dict) -> Dict:
    method to_factored (line 107) | def to_factored(self, ac: Dict) -> Dict:
    method get_action_space_update (line 110) | def get_action_space_update(self):
    method get_zero_action (line 117) | def get_zero_action(self):
  class CameraHierarchicalMapping (line 120) | class CameraHierarchicalMapping(ActionMapping):
    method __init__ (line 134) | def __init__(self, *args, **kwargs):
    method _precompute_to_factored (line 151) | def _precompute_to_factored(self):
    method from_factored (line 179) | def from_factored(self, ac: Dict) -> Dict:
    method to_factored (line 215) | def to_factored(self, ac: Dict) -> Dict:
    method get_action_space_update (line 227) | def get_action_space_update(self):
    method get_zero_action (line 233) | def get_zero_action(self):

FILE: metrics/IDM/lib/actions.py
  class Buttons (line 8) | class Buttons:
  class SyntheticButtons (line 36) | class SyntheticButtons:
  class QuantizationScheme (line 43) | class QuantizationScheme:
  class CameraQuantizer (line 49) | class CameraQuantizer:
    method discretize (line 82) | def discretize(self, xy):
    method undiscretize (line 94) | def undiscretize(self, xy):
  class ActionTransformer (line 105) | class ActionTransformer:
    method __init__ (line 109) | def __init__(
    method camera_zero_bin (line 123) | def camera_zero_bin(self):
    method discretize_camera (line 126) | def discretize_camera(self, xy):
    method undiscretize_camera (line 129) | def undiscretize_camera(self, pq):
    method item_embed_id_to_name (line 132) | def item_embed_id_to_name(self, item_id):
    method dict_to_numpy (line 135) | def dict_to_numpy(self, acs):
    method numpy_to_dict (line 154) | def numpy_to_dict(self, acs):
    method policy2env (line 167) | def policy2env(self, acs):
    method env2policy (line 171) | def env2policy(self, acs):

FILE: metrics/IDM/lib/impala_cnn.py
  class CnnBasicBlock (line 13) | class CnnBasicBlock(nn.Module):
    method __init__ (line 20) | def __init__(
    method forward (line 50) | def forward(self, x):
  class CnnDownStack (line 55) | class CnnDownStack(nn.Module):
    method __init__ (line 69) | def __init__(
    method forward (line 114) | def forward(self, x):
    method output_shape (line 123) | def output_shape(self, inshape):
  class ImpalaCNN (line 132) | class ImpalaCNN(nn.Module):
    method __init__ (line 148) | def __init__(
    method forward (line 187) | def forward(self, x):

FILE: metrics/IDM/lib/masked_attention.py
  function get_band_diagonal_mask (line 12) | def get_band_diagonal_mask(t: int, T: int, maxlen: int, batchsize: int, ...
  function get_mask (line 47) | def get_mask(first_b11: th.Tensor, state_mask: th.Tensor, t: int, T: int...
  class MaskedAttention (line 97) | class MaskedAttention(nn.Module):
    method __init__ (line 120) | def __init__(
    method initial_state (line 153) | def initial_state(self, batchsize: int, device=None):
    method forward (line 161) | def forward(self, input_bte, first_bt, state):
    method get_log_keys (line 180) | def get_log_keys(self):

FILE: metrics/IDM/lib/minecraft_util.py
  function store_args (line 12) | def store_args(method):
  function get_norm_entropy_from_cat_head (line 37) | def get_norm_entropy_from_cat_head(module, name, masks, logits):
  function get_norm_cat_entropy (line 62) | def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch...
  function get_diag_guassian_entropy (line 77) | def get_diag_guassian_entropy(module, logits, template) -> Optional[torc...

FILE: metrics/IDM/lib/misc.py
  function intprod (line 5) | def intprod(xs):
  function safezip (line 15) | def safezip(*args):
  function transpose (line 26) | def transpose(x, before, after):
  function transpose_undo (line 37) | def transpose_undo(x, before, after, *, undo=None):
  function compose_undo (line 50) | def compose_undo(u1, u2):
  function _parse_reshape_str (line 66) | def _parse_reshape_str(s, kind):
  function _infer_part (line 82) | def _infer_part(part, concrete_dim, known, index, full_shape):
  function _infer_step (line 122) | def _infer_step(args):
  function _infer (line 135) | def _infer(known, desc, shape):
  function fixed_point (line 142) | def fixed_point(f, x, eq=None):
  function _infer_question_mark (line 153) | def _infer_question_mark(x, total_product):
  function _ground (line 172) | def _ground(x, known, infer_question_mark_with=None):
  function _handle_ellipsis (line 181) | def _handle_ellipsis(x, before, after):
  function reshape_undo (line 201) | def reshape_undo(inp, before, after, *, undo=None, known=None, **kwargs):
  function reshape (line 246) | def reshape(*args, **kwargs):
  function product (line 254) | def product(xs, one=1):
  function exact_div (line 261) | def exact_div(a, b):

FILE: metrics/IDM/lib/mlp.py
  class MLP (line 8) | class MLP(nn.Module):
    method __init__ (line 9) | def __init__(self, insize, nhidlayer, outsize, hidsize, hidactiv, dtyp...
    method forward (line 21) | def forward(self, x):
    method output_shape (line 30) | def output_shape(self):

FILE: metrics/IDM/lib/normalize_ewma.py
  class NormalizeEwma (line 6) | class NormalizeEwma(nn.Module):
    method __init__ (line 9) | def __init__(self, input_shape, norm_axes=2, beta=0.99999, per_element...
    method reset_parameters (line 22) | def reset_parameters(self):
    method running_mean_var (line 27) | def running_mean_var(self):
    method forward (line 33) | def forward(self, input_vector):
    method denormalize (line 57) | def denormalize(self, input_vector):

FILE: metrics/IDM/lib/policy.py
  class ImgPreprocessing (line 21) | class ImgPreprocessing(nn.Module):
    method __init__ (line 29) | def __init__(self, img_statistics: Optional[str] = None, scale_img: bo...
    method forward (line 39) | def forward(self, img):
  class ImgObsProcess (line 48) | class ImgObsProcess(nn.Module):
    method __init__ (line 57) | def __init__(
    method forward (line 79) | def forward(self, img):
  class MinecraftPolicy (line 83) | class MinecraftPolicy(nn.Module):
    method __init__ (line 96) | def __init__(
    method output_latent_size (line 190) | def output_latent_size(self):
    method forward (line 193) | def forward(self, ob, state_in, context):
    method initial_state (line 220) | def initial_state(self, batchsize):
  class MinecraftAgentPolicy (line 227) | class MinecraftAgentPolicy(nn.Module):
    method __init__ (line 228) | def __init__(self, action_space, policy_kwargs, pi_head_kwargs):
    method make_value_head (line 237) | def make_value_head(self, v_out_size: int, norm_type: str = "ewma", no...
    method make_action_head (line 240) | def make_action_head(self, pi_out_size: int, **pi_head_opts):
    method initial_state (line 243) | def initial_state(self, batch_size: int):
    method reset_parameters (line 246) | def reset_parameters(self):
    method forward (line 252) | def forward(self, obs, first: th.Tensor, state_in):
    method get_logprob_of_action (line 271) | def get_logprob_of_action(self, pd, action):
    method get_kl_of_action_dists (line 281) | def get_kl_of_action_dists(self, pd1, pd2):
    method get_output_for_observation (line 287) | def get_output_for_observation(self, obs, state_in, first):
    method act (line 308) | def act(self, obs, first, state_in, stochastic: bool = True, taken_act...
    method v (line 331) | def v(self, obs, first, state_in):
  class InverseActionNet (line 342) | class InverseActionNet(MinecraftPolicy):
    method __init__ (line 348) | def __init__(
    method forward (line 374) | def forward(self, ob, state_in, context):
    method _conv3d_forward (line 394) | def _conv3d_forward(self, x):
  class InverseActionPolicy (line 406) | class InverseActionPolicy(nn.Module):
    method __init__ (line 407) | def __init__(
    method make_action_head (line 424) | def make_action_head(self, **kwargs):
    method reset_parameters (line 427) | def reset_parameters(self):
    method forward (line 432) | def forward(self, obs, first: th.Tensor, state_in, **kwargs):
    method predict (line 449) | def predict(
    method initial_state (line 466) | def initial_state(self, batch_size: int):

FILE: metrics/IDM/lib/scaled_mse_head.py
  class ScaledMSEHead (line 11) | class ScaledMSEHead(nn.Module):
    method __init__ (line 16) | def __init__(
    method reset_parameters (line 29) | def reset_parameters(self):
    method forward (line 34) | def forward(self, input_data):
    method loss (line 37) | def loss(self, prediction, target):
    method denormalize (line 45) | def denormalize(self, input_data):
    method normalize (line 49) | def normalize(self, input_data):

FILE: metrics/IDM/lib/torch_util.py
  function contextmanager_to_decorator (line 23) | def contextmanager_to_decorator(cm):
  function have_cuda (line 35) | def have_cuda():
  function default_device_type (line 39) | def default_device_type():
  function set_default_torch_device (line 47) | def set_default_torch_device(device):
  function dev (line 52) | def dev():
  function zeros (line 56) | def zeros(*args, **kwargs):
  function ones (line 60) | def ones(*args, **kwargs):
  function arange (line 64) | def arange(*args, **kwargs):
  function NormedLinear (line 68) | def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):
  class LinearF16 (line 85) | class LinearF16(nn.Linear):
    method forward (line 86) | def forward(self, x):
  class LayerNormF16 (line 90) | class LayerNormF16(nn.LayerNorm):
    method forward (line 91) | def forward(self, x):
  function LayerNorm (line 95) | def LayerNorm(*args, dtype=th.float32, **kwargs):
  function flatten_image (line 107) | def flatten_image(x):
  function sequential (line 115) | def sequential(layers, x, *args, diag_name=None, use_checkpoint=False):
  function load_average_with_metadata (line 122) | def load_average_with_metadata(paths, overrides):
  function save_kwargs (line 135) | def save_kwargs(fn):
  function parse_dtype (line 156) | def parse_dtype(x):
  function index (line 184) | def index(x, i):

FILE: metrics/IDM/lib/tree_util.py
  function unzip2 (line 25) | def unzip2(xys):
  function partial (line 34) | def partial(fun, *args, **kwargs):
  function safe_zip (line 41) | def safe_zip(*args: Collection) -> List[tuple]:
  function safe_map (line 48) | def safe_map(f, *args):
  function tree_map (line 56) | def tree_map(f, tree, treat_as_leaves: Optional[List] = None):
  function tree_multimap (line 79) | def tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None):
  function prefix_multimap (line 113) | def prefix_multimap(f, treedef, tree, *rest):
  function walk_pytree (line 136) | def walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = ...
  function build_tree (line 150) | def build_tree(treedef, xs):
  function _tree_unflatten (line 159) | def _tree_unflatten(xs, treedef):
  function _num_leaves (line 167) | def _num_leaves(treedef):
  function _nested_treedef (line 171) | def _nested_treedef(inner, outer):
  class PyTreeDef (line 180) | class PyTreeDef(object):
    method __init__ (line 181) | def __init__(self, node_type, node_data, children):
    method __repr__ (line 186) | def __repr__(self):
    method __hash__ (line 194) | def __hash__(self):
    method __eq__ (line 197) | def __eq__(self, other):
    method __ne__ (line 203) | def __ne__(self, other):
  class PyLeaf (line 207) | class PyLeaf(object):
    method __repr__ (line 208) | def __repr__(self):
    method __eq__ (line 211) | def __eq__(self, other):
  class NodeType (line 215) | class NodeType(object):
    method __init__ (line 216) | def __init__(self, name, to_iterable, from_iterable):
  function register_pytree_node (line 225) | def register_pytree_node(py_type, to_iterable, from_iterable):
  function tuple_to_iterable (line 230) | def tuple_to_iterable(xs):
  function tuple_from_iterable (line 234) | def tuple_from_iterable(_keys, xs):
  function list_to_iterable (line 238) | def list_to_iterable(xs):
  function list_from_iterable (line 242) | def list_from_iterable(_keys, xs):
  function dict_to_iterable (line 246) | def dict_to_iterable(xs):
  function dict_from_iterable (line 251) | def dict_from_iterable(keys, xs):
  function ordered_dict_from_iterable (line 255) | def ordered_dict_from_iterable(keys, xs):
  function default_dict_to_iterable (line 259) | def default_dict_to_iterable(xs):
  function default_dict_from_iterable (line 263) | def default_dict_from_iterable(keys, xs):
  function none_to_iterable (line 267) | def none_to_iterable(_xs):
  function none_from_iterable (line 271) | def none_from_iterable(_keys, _xs):
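The tree_util listing above mirrors JAX-style pytree utilities (the file preview credits Google's Apache-2.0 code). As a rough illustration of what `safe_zip` and `tree_map` do, here is a minimal pure-Python sketch that handles only the built-in containers; the real module also supports `treat_as_leaves` and custom node types registered via `register_pytree_node`, so treat this as a simplification, not the repo's implementation.

```python
def safe_zip(*args):
    # Sketch of tree_util.safe_zip: a zip that refuses to silently
    # truncate when the inputs have mismatched lengths.
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, "length mismatch: " + str([len(a) for a in args])
    return list(zip(*args))

def tree_map(f, tree):
    # Sketch of tree_util.tree_map: apply f to every leaf of a nested
    # structure of dicts, lists, and tuples, rebuilding the same shape.
    # (Simplification: no treat_as_leaves, no registered node types.)
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(f, v) for v in tree)
    return f(tree)  # anything that is not a container is a leaf
```

For example, `tree_map(lambda x: x * 2, {"a": [1, 2], "b": (3,)})` rebuilds the nested structure with every leaf doubled.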

FILE: metrics/IDM/lib/util.py
  function get_module_log_keys_recursive (line 13) | def get_module_log_keys_recursive(m: nn.Module):
  class FanInInitReLULayer (line 23) | class FanInInitReLULayer(nn.Module):
    method __init__ (line 38) | def __init__(
    method forward (line 75) | def forward(self, x):
    method get_log_keys (line 84) | def get_log_keys(self):
  class ResidualRecurrentBlocks (line 91) | class ResidualRecurrentBlocks(nn.Module):
    method __init__ (line 93) | def __init__(
    method forward (line 115) | def forward(self, x, first, state):
    method initial_state (line 125) | def initial_state(self, batchsize):
  class ResidualRecurrentBlock (line 132) | class ResidualRecurrentBlock(nn.Module):
    method __init__ (line 134) | def __init__(
    method forward (line 193) | def forward(self, x, first, state):
  function recurrent_forward (line 214) | def recurrent_forward(module, x, first, state, reverse_lstm=False):
  function _banded_repeat (line 232) | def _banded_repeat(x, t):
  function bandify (line 250) | def bandify(b_nd, t, T):
  function get_norm (line 270) | def get_norm(name, d, dtype=th.float32):

FILE: metrics/IDM/lib/xf.py
  function attention (line 18) | def attention(
  class Attn (line 74) | class Attn:
    method __init__ (line 85) | def __init__(self, mask, maxlen):
    method preproc_qkv (line 89) | def preproc_qkv(self, Q_bte, K_bte, V_bte):
    method preproc_r (line 92) | def preproc_r(self, R_btn):
  function split_heads (line 96) | def split_heads(x_bte, h):
  class All2All (line 106) | class All2All(Attn):
    method __init__ (line 107) | def __init__(self, nhead, maxlen, mask=True, head_dim=None):
    method preproc_qkv (line 113) | def preproc_qkv(self, *xs):
    method preproc_r (line 121) | def preproc_r(self, R_btn):
    method postproc_a (line 125) | def postproc_a(self, A_Btq, h):
  function _required_padding (line 134) | def _required_padding(dim, target_div):
  class StridedAttn (line 141) | class StridedAttn(Attn):
    method __init__ (line 142) | def __init__(self, nhead, stride, maxlen, mask=True):
    method _preproc (line 147) | def _preproc(self, x, name, Q_t=None, Q_pad=None):
    method preproc_qkv (line 188) | def preproc_qkv(self, Q_bte, K_bte, V_bte):
    method preproc_r (line 214) | def preproc_r(self, R_bte):
  class AttentionLayerBase (line 229) | class AttentionLayerBase(nn.Module):
    method __init__ (line 230) | def __init__(
    method relattn_logits (line 265) | def relattn_logits(self, X_bte, T):
  function quick_gelu (line 274) | def quick_gelu(x):
  function act (line 278) | def act(actname, x):
  class SelfAttentionLayer (line 289) | class SelfAttentionLayer(AttentionLayerBase):
    method __init__ (line 296) | def __init__(
    method residual (line 334) | def residual(self, X_bte, state):
    method forward (line 358) | def forward(self, X_bte, state):
    method stateless_forward (line 362) | def stateless_forward(self, X_bte):
    method update_state (line 366) | def update_state(self, state, K_bte, V_bte):
    method initial_state (line 393) | def initial_state(self, batchsize, initial_T=0):
    method empty_state (line 399) | def empty_state(self):
  class PointwiseLayer (line 403) | class PointwiseLayer(nn.Module):
    method __init__ (line 408) | def __init__(self, x_size, scale, dtype, norm, actname="relu", mlp_rat...
    method residual (line 423) | def residual(self, x):
    method forward (line 428) | def forward(self, x):
  function _is_separate (line 432) | def _is_separate(sep, name):
  function make_maybe_multiscale (line 443) | def make_maybe_multiscale(make_fn, *args, seqlens, separate, name, **kwa...
  class SplitCallJoin (line 457) | class SplitCallJoin(nn.Module):
    method __init__ (line 458) | def __init__(self, mods, seqlens):
    method forward (line 463) | def forward(self, x):

FILE: metrics/common_metrics.py
  function load_videos_to_tensor (line 16) | def load_videos_to_tensor(video_dir, number_of_videos, video_length, cha...
  function main (line 58) | def main():

FILE: metrics/tabulate_all_results.py
  function tabluate_metrics (line 9) | def tabluate_metrics(input_dir,output_path):

FILE: mineworld.py
  class Buttons (line 20) | class Buttons:
  function get_args (line 167) | def get_args():
  function make_action_dict (line 178) | def make_action_dict(action_line):
  function stack_images (line 189) | def stack_images(imgs):
  function get_action_line (line 196) | def get_action_line(acts):
  function run_prediction (line 206) | def run_prediction(btns_choices, cam_x_input, cam_y_input):
  function run_prediction_n_times (line 251) | def run_prediction_n_times(n, btns_1, btns_2, btns_3, btns_4, btns_5, ca...
  function step_pred_source_video_right (line 262) | def step_pred_source_video_right(video_path, start):
  function on_download_button_click (line 269) | def on_download_button_click(fps=6):
  function cleanup_files (line 295) | def cleanup_files():
  function step_video (line 305) | def step_video(video_path, start_frame, frame_count):

FILE: utils.py
  function print0 (line 8) | def print0(*args, **kwargs):
  function tensor_to_uint8 (line 11) | def tensor_to_uint8(tensor):
  function get_obj_from_str (line 17) | def get_obj_from_str(string, reload=False):
  function instantiate_from_config (line 24) | def instantiate_from_config(config):
  function load_model_from_config (line 29) | def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  function get_valid_dirs (line 41) | def get_valid_dirs(dir1: str, dir2: Union[None, str] = None, dir3: Union...
  function get_valid_paths (line 47) | def get_valid_paths(path1: str, path2: Union[None, str] = None, path3: U...
  function load_model (line 53) | def load_model(config, ckpt, gpu, eval_mode):
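The `get_obj_from_str` / `instantiate_from_config` pair listed for utils.py follows a common config-driven instantiation pattern: a config dict carries a dotted `target` path plus optional `params`, matching the `target: lvm.LlamaLVM` layout visible in the configs/*.yaml previews below. A hedged sketch of that pattern (the repo's exact error handling and its use of OmegaConf may differ):

```python
import importlib

def get_obj_from_str(string, reload=False):
    # Resolve a dotted "module.path.Name" string to the named attribute.
    module, name = string.rsplit(".", 1)
    mod = importlib.import_module(module)
    if reload:
        importlib.reload(mod)
    return getattr(mod, name)

def instantiate_from_config(config):
    # config is assumed to carry a "target" dotted path and optional
    # "params" kwargs, as in the repo's configs/*.yaml files.
    return get_obj_from_str(config["target"])(**config.get("params", {}))
```

With this sketch, `instantiate_from_config({"target": "lvm.LlamaLVM", "params": {...}})` would import `lvm` and construct `LlamaLVM` with the given keyword arguments.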

FILE: vae.py
  class VAE (line 8) | class VAE(nn.Module):
    method __init__ (line 9) | def __init__(self,
    method init_from_ckpt (line 22) | def init_from_ckpt(
    method tokenize_images (line 46) | def tokenize_images(self, x: torch.Tensor, sane_index_shape: bool = Tr...
    method token2image (line 59) | def token2image(self, tokens):
Condensed preview — 40 files, each showing path, character count, and a content snippet (full structured content: 261K chars).
[
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 444,
    "preview": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://op"
  },
  {
    "path": "LICENSE",
    "chars": 1141,
    "preview": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any pers"
  },
  {
    "path": "README.md",
    "chars": 10143,
    "preview": "<div align=\"center\">\n\n# MineWorld <br> <sub>A Real-time Interactive World Model on Minecraft</sub>\n\n[![arXiv](https://im"
  },
  {
    "path": "SECURITY.md",
    "chars": 2656,
    "preview": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products an"
  },
  {
    "path": "SUPPORT.md",
    "chars": 1244,
    "preview": "# TODO: The maintainer of this repo has not yet edited this file\r\n\r\n**REPO OWNER**: Do you want Customer Service & Suppo"
  },
  {
    "path": "configs/1200M_16f.yaml",
    "chars": 669,
    "preview": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VA"
  },
  {
    "path": "configs/1200M_32f.yaml",
    "chars": 665,
    "preview": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VA"
  },
  {
    "path": "configs/300M_16f.yaml",
    "chars": 762,
    "preview": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VA"
  },
  {
    "path": "configs/700M_16f.yaml",
    "chars": 668,
    "preview": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VA"
  },
  {
    "path": "configs/700M_32f.yaml",
    "chars": 671,
    "preview": "model:\n  target: lvm.LlamaLVM\n  params:\n    model_class: lvm.LlamaForCausalLM\n    tokenizer_config:\n      target: vae.VA"
  },
  {
    "path": "diagonal_decoding.py",
    "chars": 22248,
    "preview": "import torch\nfrom typing import Optional\nfrom torch.nn.attention import SDPBackend\n\ndef sample_top_k(logits, temperature"
  },
  {
    "path": "inference.py",
    "chars": 7386,
    "preview": "import os\nimport cv2\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import tqdm\nfrom rich import print\nfrom PIL i"
  },
  {
    "path": "lvm.py",
    "chars": 20775,
    "preview": "\"\"\"\n    Wrap the Huggingface Transformers Llama to PyTorch Lightning Module.\n\"\"\"\nimport os\nimport sys\nimport inspect \nim"
  },
  {
    "path": "mcdataset.py",
    "chars": 14454,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\nimport os\nimport json\nimport attr\nimport collections\nimport numpy as np\n"
  },
  {
    "path": "metrics/IDM/inverse_dynamics_model.py",
    "chars": 19275,
    "preview": "# Borrowed from VPT (https://github.com/openai/Video-Pre-Training)\n\nimport numpy as np\nimport torch as th\nimport cv2\nfro"
  },
  {
    "path": "metrics/IDM/lib/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "metrics/IDM/lib/action_head.py",
    "chars": 11356,
    "preview": "import logging\nfrom typing import Any, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.func"
  },
  {
    "path": "metrics/IDM/lib/action_mapping.py",
    "chars": 11033,
    "preview": "import abc\nimport itertools\nfrom collections import OrderedDict\nfrom typing import Dict, List\n\nimport numpy as np\nfrom g"
  },
  {
    "path": "metrics/IDM/lib/actions.py",
    "chars": 5871,
    "preview": "import attr\n# import minerl.herobraine.hero.mc as mc\nimport numpy as np\n\nfrom lib.minecraft_util import store_args\n\n\ncla"
  },
  {
    "path": "metrics/IDM/lib/impala_cnn.py",
    "chars": 6088,
    "preview": "import math\nfrom copy import deepcopy\nfrom typing import Dict, List, Optional\n\nfrom torch import nn\nfrom torch.nn import"
  },
  {
    "path": "metrics/IDM/lib/masked_attention.py",
    "chars": 7749,
    "preview": "import functools\n\nimport torch as th\nfrom torch import nn\n\nimport lib.xf as xf\nfrom lib.minecraft_util import store_args"
  },
  {
    "path": "metrics/IDM/lib/minecraft_util.py",
    "chars": 3532,
    "preview": "import functools\nimport inspect\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\n\nfrom lib.action_hea"
  },
  {
    "path": "metrics/IDM/lib/misc.py",
    "chars": 7808,
    "preview": "import numpy as np\nimport torch as th\n\n\ndef intprod(xs):\n    \"\"\"\n    Product of a sequence of integers\n    \"\"\"\n    out ="
  },
  {
    "path": "metrics/IDM/lib/mlp.py",
    "chars": 914,
    "preview": "import torch as th\nfrom torch import nn\n\nfrom lib import misc\nfrom lib import torch_util as tu\n\n\nclass MLP(nn.Module):\n "
  },
  {
    "path": "metrics/IDM/lib/normalize_ewma.py",
    "chars": 2659,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\n\n\nclass NormalizeEwma(nn.Module):\n    \"\"\"Normalize a vector of obs"
  },
  {
    "path": "metrics/IDM/lib/policy.py",
    "chars": 16089,
    "preview": "from copy import deepcopy\nfrom email import policy\nfrom typing import Dict, Optional\n\nimport numpy as np\nimport torch as"
  },
  {
    "path": "metrics/IDM/lib/scaled_mse_head.py",
    "chars": 1650,
    "preview": "from typing import Dict, Optional\n\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\n\nf"
  },
  {
    "path": "metrics/IDM/lib/torch_util.py",
    "chars": 5341,
    "preview": "import functools\nimport itertools\nimport math\nimport os\nimport pickle\nimport re\nimport subprocess\nimport tempfile\nfrom c"
  },
  {
    "path": "metrics/IDM/lib/tree_util.py",
    "chars": 9245,
    "preview": "# Copyright 2018 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
  },
  {
    "path": "metrics/IDM/lib/util.py",
    "chars": 9135,
    "preview": "from typing import Dict, Optional\n\nimport torch as th\nfrom torch import nn\nfrom torch.nn import functional as F\n\nimport "
  },
  {
    "path": "metrics/IDM/lib/xf.py",
    "chars": 15822,
    "preview": "\"\"\"\nImplementation of transformer and reshaping-based sparse transformer\n\"\"\"\nimport functools\nimport math\n\nimport torch "
  },
  {
    "path": "metrics/common_metrics.py",
    "chars": 5203,
    "preview": "import sys\nimport os\nsys.path.append(os.getcwd())\nfrom common_metrics_on_video_quality.calculate_fvd import calculate_fv"
  },
  {
    "path": "metrics/tabulate_all_results.py",
    "chars": 1980,
    "preview": "import argparse\nimport os \nimport json\nimport sys\nimport pandas as pd\nimport numpy as np\nfrom rich import print\n\ndef tab"
  },
  {
    "path": "mineworld.py",
    "chars": 16627,
    "preview": "import os\nimport sys\nsys.path.append(os.getcwd())\nimport gradio as gr\nfrom PIL import Image\nimport numpy as np\nimport to"
  },
  {
    "path": "requirements.txt",
    "chars": 254,
    "preview": "torch==2.6.0\ntorchvision==0.21.0\nomegaconf==2.3.0\ntransformers==4.48.1\nopencv-python==4.11.0.86\nattrs==25.3.0\ndiffusers="
  },
  {
    "path": "scripts/compute_metrics.sh",
    "chars": 1576,
    "preview": "#!/bin/bash\n\nVIDEO_RESULTS_ROOT_DEFAULT=\"videos\"\nMETRICS_ROOT_DEFAULT=\"metrics_log\"\nJSONL_PATH_DEFAULT=\"validation/valid"
  },
  {
    "path": "scripts/inference_16f_models.sh",
    "chars": 488,
    "preview": "DATA_ROOT=\"validation/validation\"\n###########################\n#### Inference 1200M models \n###########################\nC"
  },
  {
    "path": "scripts/setup_metrics.sh",
    "chars": 398,
    "preview": "# clone common_metrics_on_video_quality repository\ngit clone git@github.com:CIntellifusion/common_metrics_on_video_quali"
  },
  {
    "path": "utils.py",
    "chars": 2304,
    "preview": "import torch\nimport importlib\nfrom rich import print\nfrom typing import Union\nimport numpy as np\nimport os\n\ndef print0(*"
  },
  {
    "path": "vae.py",
    "chars": 2797,
    "preview": "import torch\nimport torch.nn as nn\nimport diffusers\nfrom safetensors.torch import load_file as load_safetensors\nfrom uti"
  }
]
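All five configs/*.yaml previews above begin with the same instantiation layout consumed by `instantiate_from_config` in utils.py. Reconstructing only what the truncated previews show; the `target: vae.VAE` line is an assumption, completed from the preview's cut-off "vae.VA" using the `class VAE` entry in the vae.py symbol index, and the remaining parameters are not visible:

```yaml
model:
  target: lvm.LlamaLVM
  params:
    model_class: lvm.LlamaForCausalLM
    tokenizer_config:
      target: vae.VAE   # assumption: preview truncates at "vae.VA"
      # ...further params cut off in the preview
```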

About this extraction

This page contains the full source code of the microsoft/mineworld GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 40 files (243.3 KB), approximately 63.5k tokens, and a symbol index with 378 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a GitHub-repo-to-text converter built by Nikandr Surkov.
