Repository: microsoft/mineworld
Branch: main
Commit: 9f49efbcc68a
Files: 40
Total size: 243.3 KB
Directory structure:
gitextract_mrwf1vu0/
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── configs/
│ ├── 1200M_16f.yaml
│ ├── 1200M_32f.yaml
│ ├── 300M_16f.yaml
│ ├── 700M_16f.yaml
│ └── 700M_32f.yaml
├── diagonal_decoding.py
├── inference.py
├── lvm.py
├── mcdataset.py
├── metrics/
│ ├── IDM/
│ │ ├── inverse_dynamics_model.py
│ │ └── lib/
│ │ ├── __init__.py
│ │ ├── action_head.py
│ │ ├── action_mapping.py
│ │ ├── actions.py
│ │ ├── impala_cnn.py
│ │ ├── masked_attention.py
│ │ ├── minecraft_util.py
│ │ ├── misc.py
│ │ ├── mlp.py
│ │ ├── normalize_ewma.py
│ │ ├── policy.py
│ │ ├── scaled_mse_head.py
│ │ ├── torch_util.py
│ │ ├── tree_util.py
│ │ ├── util.py
│ │ └── xf.py
│ ├── common_metrics.py
│ └── tabulate_all_results.py
├── mineworld.py
├── requirements.txt
├── scripts/
│ ├── compute_metrics.sh
│ ├── inference_16f_models.sh
│ └── setup_metrics.sh
├── utils.py
└── vae.py
================================================
FILE CONTENTS
================================================
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
================================================
FILE: README.md
================================================
# MineWorld
A Real-time Interactive World Model on Minecraft
[](https://arxiv.org/pdf/2504.08388) [](https://aka.ms/mineworld) [](https://huggingface.co/microsoft/mineworld)
We introduce MineWorld, an interactive world model on Minecraft that brings several key advancements over existing approaches:
* 🕹️ **High generation quality**. Built on a visual-action autoregressive Transformer, MineWorld generates coherent, high-fidelity frames conditioned on both visuals and actions.
* 🕹️ **Strong controllability**. We propose benchmarks for the action-following capacity, where MineWorld shows precise and consistent behavior.
* 🕹️ **Fast inference speed**. With Diagonal Decoding, MineWorld achieves a generation rate of 4 to 7 frames per second, enabling real-time interaction in open-ended game environments.
https://github.com/user-attachments/assets/2f5b4740-badd-453c-970d-061abd367f82
## 🔥 News
* May, 2025: The model checkpoints in the [Huggingface repo](https://huggingface.co/microsoft/mineworld) have been temporally taken down.
* April, 2025: 🚀 [MineWorld](https://github.com/microsoft/mineworld) was released!
* March, 2025: 🚀 The paper of [Diagonal Decoding](https://arxiv.org/pdf/2503.14070) was released!
## 🔧 Setup
1. Clone this repository and navigate to MineWorld folder:
```bash
git clone https://github.com/microsoft/mineworld.git
cd mineworld
```
2. We provide an `requirements.txt` file for setting up a pip environment.
```bash
# 1. Prepare conda environment
conda create -n mineworld python=3.10
# 2. Activate the environment
conda activate mineworld
# 3. install our environment
pip3 install -r requirements.txt
```
We recommend using high-end GPU for inference. We have done all testing and development using A100 and H100 GPU.
## 🎈 Checkpoints
Download pre-trained models [here](https://huggingface.co/microsoft/mineworld). Each checkpoint has a corresponding config file with the same name in the `configs` folder in this repository. All models share the same vae checkpoint and config. The data structure is as follows:
```
└── checkpoints
├── 300M_16f.ckpt
├── 700M_16f.ckpt
├── 700M_32f.ckpt
├── 1200M_16f.ckpt
└── 1200M_32f.ckpt
└── vae
├── config.json
└── vae.ckpt
└── validation
└── validation.zip
└── gradio_scene
├── scene.mp4
└── scene.jsonl
```
## 🚀 Inference
We provide two ways to use our model: interacting with it in a web demo, and running locally to reproduce the evaluation results in our paper. In addition to download the checkpoints and place them in the `checkpoints` folder, it is also required to download `scene.mp4` and `scene.jsonl` when running the web demo. Make sure they are placed in the same directory.
### Run Web Demo
To launch the webpage game, run the following command:
```bash
python mineworld.py --scene "path/to/scene.mp4"
--model_ckpt "path/to/ckpt"
--config "path/to/config"
```

Once the demo is running, you can access the website through the local URL or the public URL displayed in the command line. Initialization and the first action may take some time due to compilation.
You can specify a reference frame using the `--reference_frame` option, which should be larger than `4` and smaller than the context length of the model (i.e., `16` or `32` depending on the model utilized). A higher reference frame number generally corresponds to better visual quality. Once the initial state has been set, perform the game actions by selecting options in each chatbox.
The game progresses when pressing the "Run" button, displaying the last `8` frames and the most recent frame separately. Players can also set an action count to repeat an action multiple times.
Explanations to the buttons in the web demo are as follows:
```
Start frame: select a frame in scene.mp4 with its frame index
Jump to start frame: use the selected frame as the initial state
Camera `X` and `Y`: control the camera movements between `-90` and `90` degrees
Other action buttons: same as the actions in Minecraft
Generate video: save previous game progress
```
### Run Local Inference
To run inference locally, use the following command:
```bash
python inference.py \
--data_root "/path/to/validation/dataset" \
--model_ckpt "path/to/ckpt" \
--config "path/to/config" \
--demo_num 1 \
--frames 15 \
--accelerate-algo 'naive' \
--top_p 0.8 \
--output_dir "path/to/output"
```
Check `scripts/inference_16f_models.sh` for examples. To switch between naive autoregressive decoding and diagonal decoding, change the command `--accelerate-algo` to `naive` and `image_diagd` correspondingly.
After the inference of a set of videos, you can compute the metrics and reproduce the numerical results in our paper, check and run the following scripts:
```bash
bash scripts/setup_metrics.sh # only required in the first time
bash scripts/compute_metrics.sh
```
The evalution outputs will have the following structure:
```
└── videos
├── inference_setting1
├── clip_1.mp4
└── clip_1.json
├── inference_setting2
├── clip_1.mp4
└── clip_1.json
└── metrics_log
├── fvd_inference_setting1.json
├── fvd_inference_setting2.json
├── idm_inference_setting1.json
├── idm_inference_setting2.json
└── latest_metrics.csv
```
All results will be aggregated into `metrics_log/latest_metrics.csv`.
## 💡 Intended Uses
Our model is solely trained in the Minecraft game domain. As a world model, an initial image in the game scene will be provided, and the users should select an action from the action list. Then the model will generate the next scene that takes place the selected action.
## 🪧 Out-of-scope Uses
Our models are not specifically designed for any tasks or scenarios other than the Minecraft model.
Developers should expect failures in generation results regarding the out-of-scope scenarios.
Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case, and evaluate and mitigate for privacy, safety, and fairness before using within a specific downstream use case, particularly for high-risk scenarios.
## 🤖️ Risks and Limitations
Some of the limitations of this model to be aware of include:
* Quality of Service: MineWorld is trained solely on Minecraft, so it cannot generate results for other video domains (such as internet video). And the model cannot generate videos with higher resolution.
* Information Reliability: MineWorld is trained on videos with a fixed resolution, therefore the results may lose detailed information due to the low resolution.
* MineWorld inherits any biases, errors, or omissions characteristic of its training data, which may be amplified by any AI-generated interpretations.
* MineWorld was developed for research and experimental purposes. Further testing and validation are needed before considering its application in commercial or real-world scenarios.
* The input of other images than Minecraft will result in incoherent imagery being created and should not be attempted.
* Users are responsible for sourcing their datasets legally and ethically. This could include securing appropriate copy rights, ensuring consent for use of audio/images, and/or the anonymization of data prior to use in research.
## ✏️ BibTeX
```bibtex
@article{guo2025mineworld,
title={MineWorld: a Real-Time and Open-Source Interactive World Model on Minecraft},
author={Guo, Junliang and Ye, Yang and He, Tianyu and Wu, Haoyu and Jiang, Yushu and Pearce, Tim and Bian, Jiang}
year={2025},
journal={arXiv preprint arXiv:2504.08388},
}
```
## 🤗 Acknowledgments
This codebase borrows code from [VPT](https://github.com/openai/Video-Pre-Training) and [generative-models](https://github.com/Stability-AI/generative-models). We thank them for their efforts and innovations, which have made the development process more efficient and convenient.
Thank you to everyone who contributed their wisdom and efforts to this project.
## ☎️ Contact
We welcome feedback and collaboration from our audience. If you have suggestions, questions, or observe unexpected/offensive behavior in our technology, please contact us through `tianyuhe AT microsoft.com`.
## 📄 Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## 📍 Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.
================================================
FILE: SECURITY.md
================================================
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
================================================
FILE: SUPPORT.md
================================================
# TODO: The maintainer of this repo has not yet edited this file
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
# Support
## How to file issues and get help
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
## Microsoft Support Policy
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
================================================
FILE: configs/1200M_16f.yaml
================================================
model:
target: lvm.LlamaLVM
params:
model_class: lvm.LlamaForCausalLM
tokenizer_config:
target: vae.VAE
params:
config_path: checkpoints/vae/config.json
ckpt_path: checkpoints/vae/vae.ckpt
transformer_config:
target: transformers.LlamaConfig
params:
max_position_embeddings: 5552
hidden_size: 2048
intermediate_size: 8192
num_attention_heads: 32
num_key_value_heads: 4
num_hidden_layers: 20
rope_theta: 10000.0
torch_dtype: bfloat16
rms_norm_eps: 1.0e-05
vocab_size: 8262
attention_bias: false
mlp_bias: false
================================================
FILE: configs/1200M_32f.yaml
================================================
model:
target: lvm.LlamaLVM
params:
model_class: lvm.LlamaForCausalLM
tokenizer_config:
target: vae.VAE
params:
config_path: checkpoints/vae/config.json
ckpt_path: checkpoints/vae/vae.ckpt
transformer_config:
target: transformers.LlamaConfig
params:
max_position_embeddings: 11104
hidden_size: 2048
intermediate_size: 8192
num_attention_heads: 32
num_key_value_heads: 4
num_hidden_layers: 20
rope_theta: 10000.0
torch_dtype: bfloat16
rms_norm_eps: 1.0e-05
vocab_size: 8262
attention_bias: false
mlp_bias: false
================================================
FILE: configs/300M_16f.yaml
================================================
model:
target: lvm.LlamaLVM
params:
model_class: lvm.LlamaForCausalLM
tokenizer_config:
target: vae.VAE
params:
config_path: checkpoints/vae/config.json
ckpt_path: checkpoints/vae/vae.ckpt
transformer_config:
target: transformers.LlamaConfig
params:
max_position_embeddings: 5552
hidden_size: 1024
intermediate_size: 4096
num_attention_heads: 16
num_key_value_heads: 4
num_hidden_layers: 20
initializer_range: 0.02
rope_theta: 10000.0
torch_dtype: bfloat16
rms_norm_eps: 1.0e-05
vocab_size: 8262
attention_bias: false
mlp_bias: false
token_num: 347
image_num: 336
frame: 16
================================================
FILE: configs/700M_16f.yaml
================================================
model:
target: lvm.LlamaLVM
params:
model_class: lvm.LlamaForCausalLM
tokenizer_config:
target: vae.VAE
params:
config_path: checkpoints/vae/config.json
ckpt_path: checkpoints/vae/vae.ckpt
transformer_config:
target: transformers.LlamaConfig
params:
max_position_embeddings: 5552
hidden_size: 2048
intermediate_size: 4096
num_attention_heads: 32
num_key_value_heads: 4
num_hidden_layers: 20
rope_theta: 10000.0
torch_dtype: bfloat16
rms_norm_eps: 1.0e-05
vocab_size: 8262
attention_bias: false
mlp_bias: false
================================================
FILE: configs/700M_32f.yaml
================================================
model:
target: lvm.LlamaLVM
params:
model_class: lvm.LlamaForCausalLM
tokenizer_config:
target: vae.VAE
params:
config_path: checkpoints/vae/config.json
ckpt_path: checkpoints/vae/vae.ckpt
transformer_config:
target: transformers.LlamaConfig
params:
max_position_embeddings: 11104
hidden_size: 2048
intermediate_size: 4096
num_attention_heads: 32
num_key_value_heads: 4
num_hidden_layers: 20
rope_theta: 10000.0
torch_dtype: bfloat16
rms_norm_eps: 1.0e-05
vocab_size: 8262
attention_bias: false
mlp_bias: false
================================================
FILE: diagonal_decoding.py
================================================
import torch
from typing import Optional
from torch.nn.attention import SDPBackend
def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192):
"""
Sample from the logits using top-k sampling.
Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
"""
# logits: [batch_size, seq_len, vocab_size]
if temperature == 0.0:
idx_next = torch.argmax(logits[:, -1, :vocab_size], dim=-1, keepdim=True)
else:
probs = logits_to_probs(logits[:, -1, :vocab_size], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next
def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
"""
Multinomial sampling without a cuda synchronization.
Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
"""
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype)
def logits_to_probs(
logits,
temperature: float = 1.0,
top_k: Optional[int] = None,
):
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
def sample_top_p(logits, temperature, top_p, vocab_size=8192):
probs = torch.softmax(logits[:, -1, :vocab_size] / temperature, dim=-1)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > top_p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
def sample_n_top_p(logits, temperature, top_p, vocab_size=8192):
probs = torch.softmax(logits[:, :, :vocab_size] / temperature, dim=-1)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > top_p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
def sample_n_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None, vocab_size=8192):
if temperature == 0.0:
# Modify for multiple logits (n items)
idx_next = torch.argmax(logits[:, :, :vocab_size], dim=-1, keepdim=True) # Use all n logits for top-k
probs = None
else:
probs = logits_to_n_probs(logits[:, :, :vocab_size], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next
def logits_to_n_probs(
logits,
temperature: float = 1.0,
top_k: Optional[int] = None,
):
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
def decode_one_token(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
):
"""
Decode a single token from the autoregressive model.
"""
logits = model(input_ids=input_ids, position_ids=position_ids)
if top_p is not None:
return sample_top_p(logits, temperature=temperature, top_p=top_p)
else:
return sample_top_k(logits, temperature=temperature, top_k=top_k)
def decode_some_token(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
):
"""
Decode multi token from the autoregressive model.
"""
logits = model(input_ids=input_ids, position_ids=position_ids)
if top_p is not None:
return sample_n_top_p(logits, temperature=temperature, top_p=top_p)
else:
return sample_n_top_k(logits, temperature=temperature, top_k=top_k)
def decode_n_tokens(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
num_generate_tokens: int,
temperature: float = 1.0,
top_p: Optional[float] = 0.8,
top_k: Optional[int] = None,
decode_one_token_function=decode_one_token,
pixnum: int = 336,
actnum: int = 11,
**kwargs,
):
"""
Decode n tokens from the autoregressive model.
Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
"""
new_tokens = [input_ids]
pos_ = position_ids
assert (
top_p is None or top_k is None
), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
for t in range(num_generate_tokens):
with torch.nn.attention.sdpa_kernel(
SDPBackend.MATH
): # Actually better for Inductor to codegen attention here
next_token = decode_one_token_function(
model,
input_ids=input_ids,
position_ids=position_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
pos_ += 1
position_ids = pos_
new_tokens.append(next_token.clone())
input_ids = next_token.clone()
if (pos_ - pixnum + 1) % (actnum + pixnum) == 0 and t+2 < num_generate_tokens:
action = kwargs["action"][ (t+2) // pixnum ]
input_ids = torch.cat((input_ids, action), dim=-1)
position_ids = torch.tensor([pos_ + _ for _ in range(actnum+1)], dtype=torch.long, device="cuda")
pos_ += actnum
return new_tokens
def decode_n_tokens_for_gradio(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
num_generate_tokens: int,
temperature: float = 1.0,
top_p: Optional[float] = 0.8,
top_k: Optional[int] = None,
decode_one_token_function=decode_one_token,
):
"""
Decode n tokens from the autoregressive model.
Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
"""
new_tokens = []
assert (
top_p is None or top_k is None
), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
position_id = position_ids[-1].unsqueeze(0)
assert num_generate_tokens % 336 == 1, "should be pixnum x n + 1 to fill kvcache"
for t in range(num_generate_tokens):
with torch.nn.attention.sdpa_kernel(
SDPBackend.MATH
): # Actually better for Inductor to codegen attention here
next_token = decode_one_token_function(
model,
input_ids=input_ids,
position_ids=position_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
position_id += 1
position_ids = position_id
new_tokens.append(next_token.clone())
input_ids = next_token.clone()
return new_tokens[:-1], position_id
def prefill(
model,
input_ids: torch.Tensor = None,
position_ids: torch.Tensor = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = 0.8,
**kwargs,
):
logits = model(input_ids=input_ids, position_ids=position_ids)
# Only top-p or top-k can be provided
assert (
top_p is None or top_k is None
), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
if top_p is not None:
return sample_top_p(logits, temperature=temperature, top_p=top_p)
else:
return sample_top_k(logits, temperature=temperature, top_k=top_k)
def img_diagd_prepare_inputs(
ongoing_row_list,
row_token_num,
ongoing_input,
prompt,
imagenum,
pixnum: int = 336,
actnum: int = 11,
columnnum: int = 24,
promptlen: int = 347,
**kwargs
):
position_ids = []
for i in ongoing_row_list:
global_idx = promptlen + i * columnnum + row_token_num[i] - 1 + (imagenum - 1) * (pixnum + actnum)
position_ids.append(global_idx)
if row_token_num[ongoing_row_list[-1]] == 0:
append_policy = kwargs.get("append_policy", True)
if append_policy:
idx_in_input_ids = ongoing_row_list[-1] * columnnum - 1
ongoing_input.append(prompt[:, idx_in_input_ids].unsqueeze(-1))
else:
ongoing_input.append(ongoing_input[-1])
input_ids = torch.cat(ongoing_input, dim=1)
position_ids = torch.tensor(position_ids, device="cuda")
return input_ids, position_ids
def img_diagd_decode_n_tokens(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
num_generate_tokens: int,
temperature: float = 1.0,
top_p: Optional[float] = 0.8,
top_k: Optional[int] = None,
decode_some_token_function=decode_some_token,
pixnum: int = 336,
actnum: int = 11,
columnnum: int = 24,
rownum: int = 14,
windowsize: int = 2,
promptlen: int = 347,
**kwargs,
):
assert (
top_p is None or top_k is None
), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
imagenum = 1
cur_len = 1
num_generate_tokens += 1
prompt = kwargs.pop("prompt", None)
new_tokens = [input_ids.clone()]
row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
row_token_num[0] += 1
ongoing_row_list = [0]
ongoing_input = [input_ids.clone()]
while True:
if cur_len >= num_generate_tokens:
break
if cur_len % pixnum == 0 :#and image_start_token_id_index is None:
imagenum += 1
action = kwargs["action"][cur_len // pixnum]
ongoing_input.append(action)
input_id = torch.cat(ongoing_input, dim=-1)
position_ids = torch.arange(imagenum * (pixnum+actnum) - actnum - 1, imagenum * (pixnum+actnum), device="cuda")
image_token_num = cur_len % pixnum
if image_token_num == 1 and row_token_num[0] == windowsize:
ongoing_row_list.append(1)
if image_token_num >= 1:
input_id, position_ids = img_diagd_prepare_inputs(ongoing_row_list=ongoing_row_list, ongoing_input = ongoing_input, imagenum=imagenum, row_token_num=row_token_num, promptlen=promptlen, prompt=prompt,**kwargs)
num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1
with torch.nn.attention.sdpa_kernel(
SDPBackend.MATH
): # Actually better for Inductor to codegen attention here
next_token = decode_some_token_function(
model,
input_ids=input_id,
position_ids=position_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
ongoing_input = []
if len(ongoing_row_list) == 0:
cur_len += 1
ongoing_input.append(next_token[:,-1].clone())
new_tokens.append(next_token[:,-1].clone())
ongoing_row_list.append(0)
row_token_num[0] += 1
else:
need_remove_row = None
cur_len += num_new_tokens
for i in range(num_new_tokens):
position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0) + (imagenum - 1) * pixnum
new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())
ongoing_input.append(next_token[:,i].clone())
row_token_num[ongoing_row_list[i]] += 1
if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1:
ongoing_row_list.append(ongoing_row_list[i]+1)
elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum:
row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
ongoing_row_list = []
ongoing_input = [next_token[:,i]]
need_remove_row = None
break
if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done
ongoing_input.pop()
need_remove_row = ongoing_row_list[i]
if need_remove_row is not None:
ongoing_row_list.remove(need_remove_row)
return new_tokens
def img_diagd_prepare_inputs_for_gradio(
ongoing_row_list,
row_token_num,
ongoing_input,
pixnum: int = 336,
actnum: int = 11,
columnnum: int = 24,
promptlen: int = 347,
):
position_ids = []
for i in ongoing_row_list:
global_idx = promptlen + i * columnnum + row_token_num[i] - 1
position_ids.append(global_idx)
if row_token_num[ongoing_row_list[-1]] == 0:
ongoing_input.append(ongoing_input[-1])
input_ids = torch.cat(ongoing_input, dim=1)
position_ids = torch.tensor(position_ids, device="cuda")
return input_ids, position_ids
def img_diagd_decode_n_token_for_gradio(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
num_generate_tokens: int,
temperature: float = 1.0,
top_p: Optional[float] = 0.8,
top_k: Optional[int] = None,
decode_some_token_function=decode_some_token,
pixnum: int = 336,
columnnum: int = 24,
rownum: int = 14,
windowsize: int = 2,
):
assert (
top_p is None or top_k is None
), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
cur_len = 0
promptlen = position_ids[-1] + 1
new_tokens = []
row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
ongoing_row_list = []
ongoing_input = []
while True:
if cur_len == num_generate_tokens:
break
image_token_num = cur_len
if image_token_num == 1 and row_token_num[0] == windowsize:
ongoing_row_list.append(1)
if image_token_num == 0:
input_id = input_ids
if image_token_num >=1:
input_id, position_ids = img_diagd_prepare_inputs_for_gradio(ongoing_row_list=ongoing_row_list, ongoing_input = ongoing_input, row_token_num=row_token_num, promptlen=promptlen)
num_new_tokens = input_id.shape[1] if len(ongoing_row_list) > 0 else 1
with torch.nn.attention.sdpa_kernel(
SDPBackend.MATH
): # Actually better for Inductor to codegen attention here
next_token = decode_some_token_function(
model,
input_ids=input_id,
position_ids=position_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
ongoing_input = []
if len(ongoing_row_list) == 0:
cur_len += 1
ongoing_input.append(next_token[:,-1].clone())
new_tokens.append(next_token[:,-1].clone())
ongoing_row_list.append(0)
row_token_num[0] += 1
else:
need_remove_row = None
cur_len += num_new_tokens
for i in range(num_new_tokens):
position_in_new_tokens = torch.sum(row_token_num[:(ongoing_row_list[i] + 1)], dim=0)
new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())
ongoing_input.append(next_token[:,i].clone())
row_token_num[ongoing_row_list[i]] += 1
if row_token_num[ongoing_row_list[i]] == windowsize and ongoing_row_list[i] < rownum - 1:
ongoing_row_list.append(ongoing_row_list[i]+1)
elif ongoing_row_list[i] == rownum - 1 and row_token_num[ongoing_row_list[i]] == columnnum:
row_token_num = torch.zeros((rownum,), dtype=torch.long, device="cuda")
ongoing_row_list = []
ongoing_input = [next_token[:,i]]
need_remove_row = None
break
if row_token_num[ongoing_row_list[i]] == columnnum: ## this row is done
ongoing_input.pop()
need_remove_row = ongoing_row_list[i]
if need_remove_row is not None:
ongoing_row_list.remove(need_remove_row)
with torch.nn.attention.sdpa_kernel(
SDPBackend.MATH
): # Actually better for Inductor to codegen attention here
_ = decode_some_token_function(
model,
input_ids=next_token[:,-1],
position_ids=position_ids+1,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
return new_tokens, position_ids+2
def vid_diagd_prepare_inputs(
ongoing_row_list_v,
row_token_num_v,
ongoing_input_v,
prompt,
pixnum: int = 336,
actnum: int = 11,
rownum: int = 14,
columnnum: int = 24,
promptlen: int = 347,
**kwargs
):
new_frame = False
position_ids = []
for i in ongoing_row_list_v:
global_idx = promptlen + i * columnnum + row_token_num_v[i // rownum][i % rownum] -1 + (i // rownum) * actnum
position_ids.append(global_idx)
lastrow = ongoing_row_list_v[-1]
if lastrow % rownum == 0 and row_token_num_v[lastrow // rownum][lastrow % rownum] == 0:
# WARNING
action = kwargs["action"][lastrow // rownum]
ongoing_input_v.append(action)
position_ids.pop()
pos_act = torch.arange( promptlen + (lastrow // rownum) * (pixnum+actnum) - actnum, promptlen + (lastrow // rownum) * (pixnum+actnum), device="cuda")
position_ids.extend(pos_act.unbind())
new_frame = True
elif row_token_num_v[lastrow // rownum][lastrow % rownum] == 0:
append_policy = kwargs.get("append_policy", True)
if append_policy:
idx_in_input_ids = (lastrow % rownum) * columnnum - 1
ongoing_input_v.append(prompt[:, idx_in_input_ids].unsqueeze(-1))
else:
ongoing_input_v.append(ongoing_input_v[-1])
input_ids = torch.cat(ongoing_input_v, dim=1)
position_ids = torch.tensor(position_ids, device="cuda")
return input_ids, position_ids, new_frame
def video_diagd_decode_n_tokens(
model,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
num_generate_tokens: int,
temperature: float = 1.0,
top_p: Optional[float] = 0.8,
top_k: Optional[int] = None,
decode_some_token_function=decode_some_token,
pixnum: int = 336,
actnum: int = 11,
columnnum: int = 24,
rownum: int = 14,
windowsize: int = 2,
promptlen: int = 347,
**kwargs,
):
assert (
top_p is None or top_k is None
), "Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
cur_len = 1
num_generate_tokens += 1
prompt = kwargs.pop("prompt", None)
new_tokens = [input_ids.clone()]
row_token_num_v = []
ongoing_row_list_v = [0]
row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device="cuda"))
row_token_num_v[0][0] += 1
if row_token_num_v[0][0] == windowsize:
ongoing_row_list_v.append(1)
ongoing_input_v = [input_ids.clone()]
while True:
if cur_len >= num_generate_tokens:
break
input_id, position_ids, new_frame = vid_diagd_prepare_inputs(ongoing_row_list_v=ongoing_row_list_v, ongoing_input_v = ongoing_input_v, row_token_num_v=row_token_num_v, promptlen=promptlen, prompt=prompt, **kwargs)
num_new_tokens = input_id.shape[1]
with torch.nn.attention.sdpa_kernel(
SDPBackend.MATH
): # Actually better for Inductor to codegen attention here
next_token = decode_some_token_function(
model,
input_ids=input_id,
position_ids=position_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
ongoing_input_v = []
if new_frame:
next_token = torch.cat([next_token[:,:-actnum], next_token[:,-1:]], dim=1)
num_new_tokens = num_new_tokens - actnum + 1
need_remove_row = None
cur_len += num_new_tokens
for i in range(num_new_tokens):
last_frame = torch.stack(row_token_num_v[:ongoing_row_list_v[i] // rownum]).sum() if ongoing_row_list_v[i] // rownum > 0 else torch.tensor(0, dtype=torch.long, device="cuda")
position_in_new_tokens = last_frame + torch.sum(row_token_num_v[ongoing_row_list_v[i] // rownum][:(ongoing_row_list_v[i] % rownum + 1)], dim=0)
new_tokens.insert(position_in_new_tokens, next_token[:,i].clone())
ongoing_input_v.append(next_token[:,i].clone())
row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] += 1
# WARNING
if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == windowsize and ongoing_row_list_v[i] < rownum * (num_generate_tokens//pixnum) - 1:
ongoing_row_list_v.append(ongoing_row_list_v[i]+1)
if ongoing_row_list_v[-1] % rownum == 0:
row_token_num_v.append(torch.zeros((rownum,), dtype=torch.long, device="cuda"))
if row_token_num_v[ongoing_row_list_v[i] // rownum][ongoing_row_list_v[i] % rownum] == columnnum:
ongoing_input_v.pop()
need_remove_row = ongoing_row_list_v[i]
if need_remove_row is not None:
ongoing_row_list_v.remove(need_remove_row)
return new_tokens
================================================
FILE: inference.py
================================================
import os
import cv2
import torch
import time
import numpy as np
from tqdm import tqdm
from rich import print
from PIL import Image
from pathlib import Path
from torch import autocast
from einops import rearrange
from mcdataset import MCDataset
from omegaconf import OmegaConf
from torchvision import transforms
from argparse import ArgumentParser
from utils import load_model, tensor_to_uint8
torch.backends.cuda.matmul.allow_tf32 = False
ACCELERATE_ALGO = [
'naive','image_diagd'
]
TARGET_SIZE=(224,384)
TOKEN_PER_IMAGE = 347 # IMAGE = PIX+ACTION
TOKEN_PER_PIX = 336
safe_globals = {"array": np.array}
def token2video(code_list, tokenizer, save_path, fps, device = 'cuda'):
"""
change log: we don't perform path processing inside functions to enable extensibility
save_path: str, path to save the video, expect to endwith .mp4
"""
if len(code_list) % TOKEN_PER_PIX != 0:
print(f"code_list length {len(code_list)} is not multiple of {TOKEN_PER_PIX}")
return
num_images = len(code_list) // TOKEN_PER_PIX
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoWriter(save_path, fourcc, fps, (384, 224))
for i in range(num_images):
code = code_list[i*TOKEN_PER_PIX:(i+1)*TOKEN_PER_PIX]
code = torch.tensor([int(x) for x in code], dtype=torch.long).to(device)
img = tokenizer.token2image(code) # pixel
frame = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
video.write(frame)
video.release()
def get_args():
parser = ArgumentParser()
parser.add_argument('--data_root', type=str, required=True)
parser.add_argument('--model_ckpt', type=str, required=True)
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--output_dir', type=str, required=True)
parser.add_argument('--demo_num', type=int, default=1)
parser.add_argument('--frames', type=int, required=True)
parser.add_argument('--window_size', type=int, default=2)
parser.add_argument('--accelerate-algo', type=str, default='naive', help=f"Accelerate Algorithm Option: {ACCELERATE_ALGO}")
parser.add_argument('--fps', type=int, default=6)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--top_k', type=int, help='Use top-k sampling')
group.add_argument('--top_p', type=float, help='Use top-p (nucleus) sampling')
parser.add_argument('--val_data_num', type=int, default=500, help="number of validation data")
args = parser.parse_args()
return args
def lvm_generate(args, model, output_dir, demo_video):
"""
"""
### 1. set video input/output path
input_mp4_path = os.path.join(args.data_root, demo_video)
input_action_path = os.path.join(args.data_root, demo_video.replace('mp4','jsonl'))
output_mp4_path = str(output_dir / demo_video)
output_action_path = output_mp4_path.replace('.mp4', '.jsonl')
# backup action
os.system(f"cp {input_action_path} {output_action_path}")
if os.path.exists(output_mp4_path):
print(f"output path {output_mp4_path} exist")
return {}
device = model.transformer.device
### 2. load action into list
action_list = []
mcdataset = MCDataset()
with open(input_action_path, 'r') as f:
for line in f:
line = eval(line.strip(), {"__builtins__": None}, safe_globals)
line['camera'] = np.array(line['camera'])
act_index = mcdataset.get_action_index_from_actiondict(line, action_vocab_offset=8192)
action_list.append(act_index)
### 3. load video frames
cap = cv2.VideoCapture(input_mp4_path)
start_frame = 0
end_frame = args.demo_num
frames = []
for frame_idx in range(start_frame, end_frame):
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
ret, frame = cap.read()
if not ret:
print(f"Error in reading frame {frame_idx}")
continue
cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)
frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)
frame = torch.from_numpy(frame)
frames.append(frame)
frames = torch.stack(frames, dim=0).to(device)
frames = frames.permute(0, 3, 1, 2)
frames = frames.float() / 255.0
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
frames = normalize(frames)
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
img_index = model.tokenizer.tokenize_images(frames)
img_index = rearrange(img_index, '(b t) h w -> b t (h w)', b=1)
all_generated_tokens = []
action_all = action_list[end_frame: end_frame + args.frames]
action_all = torch.tensor(action_all).unsqueeze(1).to(device)
image_input = rearrange(img_index, 'b t c -> b (t c)')
start_t = time.time()
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
if args.accelerate_algo == 'naive':
outputs = model.transformer.naive_generate( input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all, top_k=args.top_k, top_p=args.top_p)
elif args.accelerate_algo == 'image_diagd':
outputs = model.transformer.img_diagd_generate(input_ids=image_input, max_new_tokens=TOKEN_PER_PIX*args.frames, action_all=action_all,windowsize = args.window_size, top_k=args.top_k, top_p=args.top_p)
else:
raise ValueError(f"Unknown accelerate algorithm {args.accelerate_algo}")
end_t = time.time()
all_generated_tokens.extend(outputs.tolist()[0])
new_length = len(all_generated_tokens)
time_costed = end_t - start_t
token_per_sec = new_length / time_costed
frame_per_sec = token_per_sec / TOKEN_PER_PIX
print(f"{new_length} token generated; cost {time_costed:.3f} second; {token_per_sec:.3f} token/sec {frame_per_sec:.3f} fps")
token2video(all_generated_tokens, model.tokenizer, str(output_path / demo_video), args.fps, device)
# return for evaluation
return_item = {
"time_costed": time_costed,
"token_num": new_length,
}
return return_item
if __name__ == '__main__':
args = get_args()
config = OmegaConf.load(args.config)
output_path = Path(args.output_dir)
precision_scope = autocast
os.makedirs(output_path, exist_ok=True)
model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True)
print(f"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] Load Model From {args.model_ckpt}")
# get accelearte algoritm
args.accelerate_algo = args.accelerate_algo.lower()
if args.accelerate_algo not in ACCELERATE_ALGO:
print(f"[bold red][Warning][/bold red] {args.accelerate_algo} is not in {ACCELERATE_ALGO}, use naive")
args.accelerate_algo = 'naive'
num_item = 0
for root, _, files in os.walk(args.data_root):
files = [f for f in files if f.endswith('.mp4')] # mp4 would not influence progress bar
files = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0]))
for file in tqdm(files):
return_item = lvm_generate(args, model, output_path,file)
num_item += 1
if num_item >= args.val_data_num:
print(f"[bold magenta][MINEWORLD][INFERENCE][/bold magenta] reach val data num limit {args.val_data_num}")
break
================================================
FILE: lvm.py
================================================
"""
Wrap the Huggingface Transformers Llama to PyTorch Lightning Module.
"""
import os
import sys
import inspect
import torch
from typing import Optional
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers import LlamaConfig
from utils import get_obj_from_str, instantiate_from_config
from diagonal_decoding import decode_one_token, decode_some_token, decode_n_tokens, decode_n_tokens_for_gradio, prefill, img_diagd_decode_n_tokens, video_diagd_decode_n_tokens, img_diagd_decode_n_token_for_gradio
torch.backends.cuda.matmul.allow_tf32 = False
logger = logging.get_logger(__name__)
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
if not (parentdir in sys.path):
sys.path.insert(0, parentdir)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
"""
Apply rotary position embeddings to query and key tensors.
Args:
q (torch.Tensor): Query tensor.
k (torch.Tensor): Key tensor.
cos (torch.Tensor): Cosine values.
sin (torch.Tensor): Sine values.
position_ids (torch.Tensor): Position IDs.
Returns:
torch.Tensor: Query and key tensors with rotary position embeddings applied.
"""
cos = cos[position_ids].unsqueeze(0).unsqueeze(2)
sin = sin[position_ids].unsqueeze(0).unsqueeze(2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaLVM(torch.nn.Module):
def __init__(
self,
transformer_config,
model_class: str,
tokenizer_config = None,
):
super().__init__()
self.config = instantiate_from_config(transformer_config)
self.transformer = get_obj_from_str(model_class)(self.config)
if tokenizer_config is not None:
self.tokenizer = instantiate_from_config(tokenizer_config)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class LlamaRotaryEmbedding(nn.Module):
def __init__(
self,
device=None,
config: Optional[LlamaConfig] = None,
):
super().__init__()
self.rope_kwargs = {}
self.rope_type = "default"
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.max_position_embeddings = config.max_position_embeddings
inv_freq, _ = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)
def _set_cos_sin_cache(self, device, dtype):
"""
Set the cosine and sine cache for positional embeddings.
Args:
seq_len (int): The sequence length.
device (str): The device on which the cache tensors will be stored.
dtype: The data type of the cache tensors.
"""
t = torch.arange(
self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
"cos_cached", emb.cos().to(dtype), persistent=False
)
self.register_buffer(
"sin_cached", emb.sin().to(dtype), persistent=False
)
def forward(self, x, seq_len=None):
"""
Forward pass of the LlamaRotaryEmbedding module.
Args:
x (torch.Tensor): Input tensor of shape [bs, num_attention_heads, seq_len, head_size].
seq_len (int): The sequence length. If greater than the cached length, the cache will be updated.
Returns:
tuple: A tuple containing two tensors, the cosine and sine embeddings, both of shape [1, 1, seq_len, dim].
"""
if seq_len > self.max_position_embeddings:
raise ValueError("seq length should less than max embedding")
return (
self.cos_cached[:seq_len, :].to(dtype=x.dtype),
self.sin_cached[:seq_len, :].to(dtype=x.dtype),
)
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
assert (self.head_dim * self.num_heads) == self.hidden_size, "hidden_size must be divisible by num_heads"
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.max_batch_size = getattr(config, "max_batch_size", 1)
self.init_kv_cache()
def init_kv_cache(self, dtype=torch.float16):
cache_shape = (self.max_batch_size, self.max_position_embeddings, self.num_key_value_heads, self.head_dim)
self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()
self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
positions_embedding = None,
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
)
cos, sin = positions_embedding
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
self.cache_k[:bsz, position_ids] = key_states
self.cache_v[:bsz, position_ids] = value_states
key_states, value_states = (
self.cache_k[:bsz, :, :],
self.cache_v[:bsz, :, :],
)
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=2)
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=2)
query_states, key_states, value_states = map(lambda x: x.transpose(1, 2), (query_states, key_states, value_states))
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
).transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
positions_embedding = None,
):
"""
Forward pass for the LlamaDecoderLayer.
Args:
hidden_states (torch.FloatTensor): Input tensor of shape `(batch, seq_len, embed_dim)`.
attention_mask (torch.FloatTensor, optional): Attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_ids (torch.LongTensor, optional): Positional IDs tensor.
Returns:
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: Tuple containing:
- hidden_states (torch.FloatTensor): Output tensor.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
positions_embedding=positions_embedding,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class LlamaModel(PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.max_position_embedding = config.max_position_embeddings
self.causal_mask = torch.tril(
torch.ones(self.max_position_embedding, self.max_position_embedding, dtype=torch.bool)
).cuda()
self.post_init()
def _create_attention_mask(self, input_pos: Optional[torch.Tensor]):
"""
Creates an attention mask for the transformer layers.
Args:
input_pos[torch.Tensor]: The position of input sequence (used for inference only).
Returns:
Optional[torch.Tensor]: The attention mask, or None for causal mask.
"""
mask = self.causal_mask[input_pos]
return mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
if input_ids is None:
raise ValueError(
"decoder_input_ids is None"
)
hidden_states = self.embed_tokens(input_ids)
positions_embedding = self.rotary_emb(hidden_states, seq_len=self.max_position_embedding)
attention_mask = self._create_attention_mask(input_pos=position_ids)
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
positions_embedding=positions_embedding,
)
hidden_states = layer_outputs
hidden_states = self.norm(hidden_states)
return hidden_states
class LlamaForCausalLM(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
):
outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
)
logits = self.lm_head(outputs[:, :, :])
return logits
def refresh_kvcache(self):
for i in self.model.layers:
i.self_attn.init_kv_cache()
def naive_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, top_p=None, top_k=None):
self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
if action_all is not None:
input_ids = torch.cat([input_ids, action_all[0]], dim=-1)
position_ids = torch.arange(0, input_ids.shape[1], device="cuda")
next_token = self.prefill(
self,
input_ids=input_ids,
position_ids=position_ids,
temperature=temperature,
top_k = top_k,
top_p = top_p,
)
self.decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True)
position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda")
generated_tokens = decode_n_tokens(
self,
input_ids = next_token.view(1, -1),
position_ids = position_ids,
num_generate_tokens = max_new_tokens - 1,
temperature = temperature,
decode_one_token_function=self.decode_one_token,
action=action_all,
top_p = top_p,
top_k = top_k,
)
return torch.cat(generated_tokens, dim=1)
def prefill_for_gradio(self, input_ids, temperature=1.0):
self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
last_pos = input_ids.shape[1]
position_ids = torch.arange(0, last_pos, device="cuda")
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
next_token = self.prefill(
self,
input_ids=input_ids,
position_ids=position_ids,
temperature=temperature,
)
return next_token, last_pos
def decode_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0):
self.decode_one_token = torch.compile(decode_one_token, mode="max-autotune", fullgraph=True)
# self.decode_one_token = decode_one_token
# WARNING
position_ids = torch.arange(position_id, position_id + input_action.shape[1], device="cuda")
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
generated_tokens, position_id = decode_n_tokens_for_gradio(
self,
input_ids = input_action,
position_ids = position_ids,
num_generate_tokens = max_new_tokens,
temperature = temperature,
decode_one_token_function=self.decode_one_token,
)
# WARNING
return generated_tokens, position_id
def diagd_img_token_for_gradio(self, input_action, position_id, max_new_tokens, temperature=1.0, windowsize=2):
self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True)
position_ids = torch.arange(position_id, position_id + input_action.shape[1], device="cuda")
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
generated_tokens, position_id = img_diagd_decode_n_token_for_gradio(
self,
input_ids = input_action,
position_ids = position_ids,
num_generate_tokens = max_new_tokens,
temperature = temperature,
decode_some_token_function=self.decode_some_token,
windowsize = windowsize,
)
return generated_tokens, position_id
def img_diagd_generate(self, input_ids, max_new_tokens, temperature=1.0, action_all=None, windowsize=2, top_p=None, top_k=None):
self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
input_ids = torch.cat([input_ids, action_all[0]], dim=-1)
position_ids = torch.arange(0, input_ids.shape[1], device="cuda")
next_token = self.prefill(
self,
input_ids=input_ids,
position_ids=position_ids,
temperature=temperature,
top_k = top_k,
top_p = top_p,
)
self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True)
position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda")
generated_tokens = img_diagd_decode_n_tokens(
self,
input_ids = next_token.view(1, -1),
position_ids = position_ids,
num_generate_tokens = max_new_tokens - 1,
temperature = temperature,
decode_some_token_function=self.decode_some_token,
windowsize = windowsize,
action=action_all,
prompt=input_ids,
top_k = top_k,
top_p = top_p,
)
return torch.cat(generated_tokens, dim=1)
def vid_diagd_generate(self, input_ids, max_new_tokens,windowsize=2, temperature=1.0, action_all=None,**kwargs):
self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
input_ids = torch.cat([input_ids, action_all[0]], dim=-1)
position_ids = torch.arange(0, input_ids.shape[1], device="cuda")
next_token = self.prefill(
self,
input_ids=input_ids,
position_ids=position_ids,
temperature=temperature,
)
self.decode_some_token = torch.compile(decode_some_token, mode="max-autotune", fullgraph=True)
# self.decode_some_token = decode_some_token
position_ids = torch.tensor([input_ids.shape[1]], dtype=torch.long, device="cuda")
generated_tokens = video_diagd_decode_n_tokens(
self,
input_ids = next_token.view(1, -1),
position_ids = position_ids,
num_generate_tokens = max_new_tokens - 1,
temperature = temperature,
decode_some_token_function=self.decode_some_token,
windowsize = windowsize,
action=action_all,
prompt=input_ids,
**kwargs
)
return torch.cat(generated_tokens, dim=1)
================================================
FILE: mcdataset.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import attr
import collections
import numpy as np
from typing import Union, Dict
import torch
from utils import print0
# https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L17
KEYBOARD_BUTTON_MAPPING = {
"key.keyboard.escape" :"ESC",
"key.keyboard.s" :"back",
"key.keyboard.q" :"drop",
"key.keyboard.w" :"forward",
"key.keyboard.1" :"hotbar.1",
"key.keyboard.2" :"hotbar.2",
"key.keyboard.3" :"hotbar.3",
"key.keyboard.4" :"hotbar.4",
"key.keyboard.5" :"hotbar.5",
"key.keyboard.6" :"hotbar.6",
"key.keyboard.7" :"hotbar.7",
"key.keyboard.8" :"hotbar.8",
"key.keyboard.9" :"hotbar.9",
"key.keyboard.e" :"inventory",
"key.keyboard.space" :"jump",
"key.keyboard.a" :"left",
"key.keyboard.d" :"right",
"key.keyboard.left.shift" :"sneak",
"key.keyboard.left.control" :"sprint",
"key.keyboard.f" :"swapHands",
}
# https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L41
# Template action
NOOP_ACTION = {
"ESC": 0,
"back": 0,
"drop": 0,
"forward": 0,
"hotbar.1": 0,
"hotbar.2": 0,
"hotbar.3": 0,
"hotbar.4": 0,
"hotbar.5": 0,
"hotbar.6": 0,
"hotbar.7": 0,
"hotbar.8": 0,
"hotbar.9": 0,
"inventory": 0,
"jump": 0,
"left": 0,
"right": 0,
"sneak": 0,
"sprint": 0,
"swapHands": 0,
"camera": np.array([0, 0]), # [y, x]
"attack": 0,
"use": 0,
"pickItem": 0,
}
OASIS_ACTION_KEYS = [
"inventory",
"ESC",
"hotbar.1",
"hotbar.2",
"hotbar.3",
"hotbar.4",
"hotbar.5",
"hotbar.6",
"hotbar.7",
"hotbar.8",
"hotbar.9",
"forward",
"back",
"left",
"right",
"cameraX",
"cameraY",
"jump",
"sneak",
"sprint",
"swapHands",
"attack",
"use",
"pickItem",
"drop",
]
# Matches a number in the MineRL Java code regarding sensitivity
# This is for mapping from recorded sensitivity to the one used in the model
CAMERA_SCALER = 360.0 / 2400.0
# https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L8 with some modifications
class Buttons:
# 14 in total without hotbar and camera
ATTACK = "attack"
BACK = "back"
FORWARD = "forward"
JUMP = "jump"
LEFT = "left"
RIGHT = "right"
SNEAK = "sneak"
SPRINT = "sprint"
USE = "use"
DROP = "drop"
INVENTORY = "inventory"
# added by Yang
ESC = "ESC"
SWAPHANDS = "swapHands"
PICKITEM = "pickItem"
ALL = [
USE,
ATTACK,
FORWARD,
BACK,
LEFT,
RIGHT,
JUMP,
SNEAK,
SPRINT,
DROP,
SWAPHANDS,
PICKITEM,
INVENTORY,
ESC,
] + [f"hotbar.{i}" for i in range(1, 10)]
class QuantizationScheme:
LINEAR = "linear"
MU_LAW = "mu_law"
# https://github.com/openai/Video-Pre-Training/blob/main/lib/actions.py#L49
@attr.s(auto_attribs=True)
class CameraQuantizer:
"""
A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components.
Parameters:
- camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize.
- camera_maxval: The maximum value of the camera action.
- quantization_scheme: The quantization scheme to use. Currently, two quantization schemes are supported:
- Linear quantization (default): Camera actions are split uniformly into discrete bins
- Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm)
followed by the same quantization scheme used by the linear scheme.
- mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of
mu will result in a sharper transition near zero. Below are some reference values listed
for choosing mu given a constant maxval and a desired max_precision value.
maxval = 10 | max_precision = 0.5 | μ ≈ 2.93826
maxval = 10 | max_precision = 0.4 | μ ≈ 4.80939
maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887
maxval = 20 | max_precision = 0.5 | μ ≈ 2.7
maxval = 20 | max_precision = 0.4 | μ ≈ 4.39768
maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194
maxval = 40 | max_precision = 0.5 | μ ≈ 2.60780
maxval = 40 | max_precision = 0.4 | μ ≈ 4.21554
maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152
"""
camera_maxval: int
camera_binsize: int
quantization_scheme: str = attr.ib(
default=QuantizationScheme.LINEAR,
validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),
)
mu: float = attr.ib(default=5)
def discretize(self, xy):
xy = np.clip(xy, -self.camera_maxval, self.camera_maxval)
if self.quantization_scheme == QuantizationScheme.MU_LAW:
xy = xy / self.camera_maxval
v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu))
v_encode *= self.camera_maxval
xy = v_encode
# Quantize using linear scheme
return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64)
def undiscretize(self, xy):
xy = xy * self.camera_binsize - self.camera_maxval
if self.quantization_scheme == QuantizationScheme.MU_LAW:
xy = xy / self.camera_maxval
v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0)
v_decode *= self.camera_maxval
xy = v_decode
return xy
class MCDataset(torch.utils.data.Dataset):
"""
Dataset for Minecraft.
"""
def __init__(self,
action_length: int = 11, # including bos and eos
camera_binsize: int = 9, # 2 in vpt
camera_maxval: int = 90, # 10 in vpt
camera_mu: float = 11.4887, # 10 in vpt
quantization_scheme: str = "mu_law",
):
self.action_length = action_length
self.camera_quantizer = CameraQuantizer(
camera_binsize=camera_binsize,
camera_maxval=camera_maxval,
mu=camera_mu,
quantization_scheme=quantization_scheme,
)
def json_action_to_env_action(self, json_action):
"""
https://github.com/openai/Video-Pre-Training/blob/aed46b90e8db2332801feabd8be2de01f92c0ad2/run_inverse_dynamics_model.py#L80
Converts a json action into a MineRL action.
Returns (minerl_action, is_null_action)
"""
# This might be slow...
env_action = NOOP_ACTION.copy()
# As a safeguard, make camera action again so we do not override anything
env_action["camera"] = np.array([0, 0])
is_null_action = True
keyboard_keys = json_action["keyboard"]["keys"]
for key in keyboard_keys:
# You can have keys that we do not use, so just skip them
# NOTE in original training code, ESC was removed and replaced with
# "inventory" action if GUI was open.
# Not doing it here, as BASALT uses ESC to quit the game.
if key in KEYBOARD_BUTTON_MAPPING:
env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1
is_null_action = False
mouse = json_action["mouse"]
camera_action = env_action["camera"]
camera_action[0] = mouse["dy"] * CAMERA_SCALER
camera_action[1] = mouse["dx"] * CAMERA_SCALER
if mouse["dx"] != 0 or mouse["dy"] != 0:
is_null_action = False
else:
if abs(camera_action[0]) > 180:
camera_action[0] = 0
if abs(camera_action[1]) > 180:
camera_action[1] = 0
mouse_buttons = mouse["buttons"]
if 0 in mouse_buttons:
env_action["attack"] = 1
is_null_action = False
if 1 in mouse_buttons:
env_action["use"] = 1
is_null_action = False
if 2 in mouse_buttons:
env_action["pickItem"] = 1
is_null_action = False
# added by Yang
# if two confictory actions are pressed, remove them
if env_action["forward"] == 1 and env_action["back"] == 1:
env_action["forward"] = 0
env_action["back"] = 0
if env_action["left"] == 1 and env_action["right"] == 1:
env_action["left"] = 0
env_action["right"] = 0
if env_action["jump"] == 1 and env_action["sneak"] == 1:
env_action["jump"] = 0
env_action["sneak"] = 0
if env_action["sprint"] == 1 and env_action["sneak"] == 1:
env_action["sprint"] = 0
env_action["sneak"] = 0
if env_action["attack"] == 1 and env_action["use"] == 1:
env_action["attack"] = 0
env_action["use"] = 0
# remove inventory and ESC action
if env_action["inventory"] == 1:
is_null_action = True
if env_action["ESC"] == 1:
is_null_action = True
return env_action, is_null_action
def make_action_vocab(self,
num_cam_bins: int = 21,
action_vocab_offset: int = 0,
verbose: bool = False):
action_vocab = collections.OrderedDict()
# 14 actions and hotbar.1-9
for i, action in enumerate(Buttons.ALL):
action_vocab[action] = i
# camera 0
for i in range(num_cam_bins):
action_vocab[f"cam_0_{i}"] = len(Buttons.ALL) + i
# camera 1
for i in range(num_cam_bins):
action_vocab[f"cam_1_{i}"] = len(Buttons.ALL) + num_cam_bins + i
# bos, null, eos
action_vocab[""] = len(Buttons.ALL) + 2 * num_cam_bins
action_vocab[""] = len(Buttons.ALL) + 2 * num_cam_bins + 1
action_vocab[""] = len(Buttons.ALL) + 2 * num_cam_bins + 2
if action_vocab_offset > 0:
action_vocab = {k: v + action_vocab_offset for k, v in action_vocab.items()}
if verbose:
print0(f"[bold yellow]\[MCDataset][/bold yellow] Action Vocab: {action_vocab}")
self.action_vocab = action_vocab
# return action_vocab
def _handle_conflict_action_index(self,
action_dict: Dict[str, Union[int, np.ndarray]],
key1: str,
key2: str,
null_key: str,
verbose: bool = False):
if action_dict[key1] == 1 and action_dict[key2] == 1:
if verbose:
print0(f"[bold yellow]\[MCDataset][/bold yellow] {key1} and {key2} are both pressed")
return self.action_vocab[null_key]
elif action_dict[key1] == 1:
return self.action_vocab[key1]
elif action_dict[key2] == 1:
return self.action_vocab[key2]
else:
return self.action_vocab[null_key]
def get_action_index_from_actiondict(self,
action_dict: Dict[str, Union[int, np.ndarray]],
action_vocab_offset: int = 0,
verbose: bool = False):
if not hasattr(self, "action_vocab"):
self.make_action_vocab(action_vocab_offset=action_vocab_offset, verbose=verbose)
# action_list = [boa, camy, camx, hotbar, fore_back, left_right, sprint_sneak, use_attack, jump, drop_pick, eoa]
# 11 actions
action_list = [self.action_vocab[""]] * self.action_length
# 0 & 10
action_list[0] = self.action_vocab[""]
action_list[-1] = self.action_vocab[""]
camera_action = action_dict["camera"]
assert len(camera_action) == 2, f"[MCDataset] camera_action length is not 2: {camera_action}"
# camera_action should be numpy array
if not isinstance(camera_action, np.ndarray):
camera_action = np.array(camera_action)
camera_action = self.camera_quantizer.discretize(camera_action)
# 1 & 2
action_list[1] = self.action_vocab[f"cam_0_{camera_action[0]}"]
action_list[2] = self.action_vocab[f"cam_1_{camera_action[1]}"]
# 3
for i in range(1, 10):
if f"hotbar.{i}" in action_dict and action_dict[f"hotbar.{i}"] == 1:
action_list[3] = self.action_vocab[f"hotbar.{i}"]
break
# 4 forward/back
action_list[4] = self._handle_conflict_action_index(action_dict, "forward", "back", "", verbose=verbose)
# 5 left/right
action_list[5] = self._handle_conflict_action_index(action_dict, "left", "right", "", verbose=verbose)
# 6 sprint/sneak
action_list[6] = self._handle_conflict_action_index(action_dict, "sprint", "sneak", "", verbose=verbose)
# 7 use/attack
action_list[7] = self._handle_conflict_action_index(action_dict, "use", "attack", "", verbose=verbose)
# 8 jump
action_list[8] = self.action_vocab["jump"] if action_dict["jump"] == 1 else self.action_vocab[""]
# 9 drop/pick
action_list[9] = self._handle_conflict_action_index(action_dict, "drop", "pickItem", "", verbose=verbose)
if verbose:
print0(f"[bold yellow]\[MCDataset][/bold yellow] Action List: {action_list}")
return action_list
def read_jsonl(self, jsonl_path: str):
assert os.path.isfile(jsonl_path), f"[MCDataset] {jsonl_path} does not exist"
# read jsonl
# https://github.com/openai/Video-Pre-Training/blob/main/data_loader.py#L76
try:
with open(jsonl_path) as json_file:
json_lines = json_file.readlines()
json_data = "[" + ",".join(json_lines) + "]"
json_data = json.loads(json_data)
except Exception as e:
print0(f"[bold yellow]\[MCDataset][/bold yellow] {jsonl_path} cannot be read: {e}")
return None
return json_data
================================================
FILE: metrics/IDM/inverse_dynamics_model.py
================================================
# Borrowed from VPT (https://github.com/openai/Video-Pre-Training)
import numpy as np
import torch as th
import cv2
from gym3.types import DictType
from gym import spaces
from tqdm import tqdm
import os
from argparse import ArgumentParser
import pickle
import cv2
import json
from lib.action_mapping import IDMActionMapping
from lib.actions import ActionTransformer
from lib.policy import InverseActionPolicy
from lib.torch_util import default_device_type, set_default_torch_device
from sklearn.metrics import precision_score, recall_score, f1_score
# Hardcoded settings
AGENT_RESOLUTION = (128, 128)
safe_globals = {"array": np.array}
def resize_image(img, target_resolution):
# For your sanity, do not resize with any function than INTER_LINEAR
img = cv2.resize(img, target_resolution, interpolation=cv2.INTER_LINEAR)
return img
ACTION_TRANSFORMER_KWARGS = dict(
camera_binsize=2,
camera_maxval=10,
camera_mu=10,
camera_quantization_scheme="mu_law",
)
class IDMAgent:
"""
Sugarcoating on the inverse dynamics model (IDM) used to predict actions Minecraft players take in videos.
Functionally same as MineRLAgent.
"""
def __init__(self, idm_net_kwargs, pi_head_kwargs, device=None):
if device is None:
device = default_device_type()
self.device = th.device(device)
# Set the default torch device for underlying code as well
set_default_torch_device(self.device)
self.action_mapper = IDMActionMapping(n_camera_bins=11)
action_space = self.action_mapper.get_action_space_update()
action_space = DictType(**action_space)
self.action_transformer = ActionTransformer(**ACTION_TRANSFORMER_KWARGS)
idm_policy_kwargs = dict(idm_net_kwargs=idm_net_kwargs, pi_head_kwargs=pi_head_kwargs, action_space=action_space)
self.policy = InverseActionPolicy(**idm_policy_kwargs).to(device)
self.hidden_state = self.policy.initial_state(1)
self._dummy_first = th.from_numpy(np.array((False,))).to(device)
def load_weights(self, path):
"""Load model weights from a path, and reset hidden state"""
self.policy.load_state_dict(th.load(path, map_location=self.device), strict=False)
self.reset()
def reset(self):
"""Reset agent to initial state (i.e., reset hidden state)"""
self.hidden_state = self.policy.initial_state(1)
def _video_obs_to_agent(self, video_frames):
imgs = [resize_image(frame, AGENT_RESOLUTION) for frame in video_frames]
# Add time and batch dim
imgs = np.stack(imgs)[None]
agent_input = {"img": th.from_numpy(imgs).to(self.device)}
return agent_input
def _agent_action_to_env(self, agent_action):
"""Turn output from policy into action for MineRL"""
# This is quite important step (for some reason).
# For the sake of your sanity, remember to do this step (manual conversion to numpy)
# before proceeding. Otherwise, your agent might be a little derp.
action = {
"buttons": agent_action["buttons"].cpu().numpy(),
"camera": agent_action["camera"].cpu().numpy()
}
minerl_action = self.action_mapper.to_factored(action)
minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
return minerl_action_transformed
def predict_actions(self, video_frames):
"""
Predict actions for a sequence of frames.
`video_frames` should be of shape (N, H, W, C).
Returns MineRL action dict, where each action head
has shape (N, ...).
Agent's hidden state is tracked internally. To reset it,
call `reset()`.
"""
agent_input = self._video_obs_to_agent(video_frames)
# The "first" argument could be used to reset tell episode
# boundaries, but we are only using this for predicting (for now),
# so we do not hassle with it yet.
dummy_first = th.zeros((video_frames.shape[0], 1)).to(self.device)
predicted_actions, self.hidden_state, _ = self.policy.predict(
agent_input, first=dummy_first, state_in=self.hidden_state,
deterministic=True
)
predicted_minerl_action = self._agent_action_to_env(predicted_actions)
return predicted_minerl_action
# NOTE: this is _not_ the original code of IDM!
# As such, while it is close and seems to function well,
# its performance might be bit off from what is reported
# in the paper.
ENV_KWARGS = dict(
fov_range=[70, 70],
frameskip=1,
gamma_range=[2, 2],
guiscale_range=[1, 1],
resolution=[640, 360],
cursor_size_range=[16.0, 16.0],
)
KEYBOARD_BUTTON_MAPPING = {
"key.keyboard.escape" :"ESC",
"key.keyboard.s" :"back",
"key.keyboard.q" :"drop",
"key.keyboard.w" :"forward",
"key.keyboard.1" :"hotbar.1",
"key.keyboard.2" :"hotbar.2",
"key.keyboard.3" :"hotbar.3",
"key.keyboard.4" :"hotbar.4",
"key.keyboard.5" :"hotbar.5",
"key.keyboard.6" :"hotbar.6",
"key.keyboard.7" :"hotbar.7",
"key.keyboard.8" :"hotbar.8",
"key.keyboard.9" :"hotbar.9",
"key.keyboard.e" :"inventory",
"key.keyboard.space" :"jump",
"key.keyboard.a" :"left",
"key.keyboard.d" :"right",
"key.keyboard.left.shift" :"sneak",
"key.keyboard.left.control" :"sprint",
"key.keyboard.f" :"swapHands",
}
# Template action
NOOP_ACTION = {
"ESC": 0,
"back": 0,
"drop": 0,
"forward": 0,
"hotbar.1": 0,
"hotbar.2": 0,
"hotbar.3": 0,
"hotbar.4": 0,
"hotbar.5": 0,
"hotbar.6": 0,
"hotbar.7": 0,
"hotbar.8": 0,
"hotbar.9": 0,
"inventory": 0,
"jump": 0,
"left": 0,
"right": 0,
"sneak": 0,
"sprint": 0,
"swapHands": 0,
"camera": np.array([0, 0]),
"attack": 0,
"use": 0,
"pickItem": 0,
}
# Matches a number in the MineRL Java code regarding sensitivity
# This is for mapping from recorded sensitivity to the one used in the model
CAMERA_SCALER = 360.0 / 2400.0
def json_action_to_env_action(json_action):
"""
Converts a json action into a MineRL action.
Returns (minerl_action, is_null_action)
"""
if "ESC" in json_action:
return json_action, False
# This might be slow...
env_action = NOOP_ACTION.copy()
# As a safeguard, make camera action again so we do not override anything
env_action["camera"] = np.array([0, 0])
is_null_action = True
keyboard_keys = json_action["keyboard"]["keys"]
for key in keyboard_keys:
# You can have keys that we do not use, so just skip them
# NOTE in original training code, ESC was removed and replaced with
# "inventory" action if GUI was open.
# Not doing it here, as BASALT uses ESC to quit the game.
if key in KEYBOARD_BUTTON_MAPPING:
env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1
is_null_action = False
mouse = json_action["mouse"]
camera_action = env_action["camera"]
camera_action[0] = mouse["dy"] * CAMERA_SCALER
camera_action[1] = mouse["dx"] * CAMERA_SCALER
if mouse["dx"] != 0 or mouse["dy"] != 0:
is_null_action = False
else:
if abs(camera_action[0]) > 180:
camera_action[0] = 0
if abs(camera_action[1]) > 180:
camera_action[1] = 0
mouse_buttons = mouse["buttons"]
if 0 in mouse_buttons:
env_action["attack"] = 1
is_null_action = False
if 1 in mouse_buttons:
env_action["use"] = 1
is_null_action = False
if 2 in mouse_buttons:
env_action["pickItem"] = 1
is_null_action = False
return env_action, is_null_action
def load_action_jsonl(json_path):
with open(json_path) as json_file:
json_lines = json_file.readlines()
json_data = "[" + ",".join(json_lines) + "]"
json_data = json.loads(json_data)
return json_data
# loss on frame - avg on video - avg on dataset
def evaluate_IDM_quality(model, weights,jsonl_folder, video_folder, infer_demo_num, n_frames, output_file):
"""
Evaluate the quality of a IDM model on a dataset of videos.
Args:
video_folder (str): Path to the folder containing videos.
model (str): Path to the '.model' file to be loaded.
weights (str): Path to the '.weights' file to be loaded.
n_batches (int): Number of batches to process.
n_frames (int): Number of frames to process at a time.
"""
## set up IDM model
agent_parameters = pickle.load(open(model, "rb"))
net_kwargs = agent_parameters["model"]["args"]["net"]["args"]
pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
# pi_head_kwargs["temperature"] = 1.0
agent = IDMAgent(idm_net_kwargs=net_kwargs, pi_head_kwargs=pi_head_kwargs)
agent.load_weights(weights)
# Load video files
video_files = os.listdir(video_folder)
video_files = [f for f in video_files if f.endswith(".mp4")]
video_files = sorted(video_files)
video_files = [os.path.join(video_folder, f) for f in video_files]
eval_num = min(500,len(video_files))
video_files = video_files[:eval_num]
dataset_labels = {}
camera_loss_list = []
for video_file in tqdm(video_files):
json_file = os.path.join(jsonl_folder,os.path.basename(video_file).replace(".mp4",".jsonl"))
# old implementation
# action_loss,video_avg_loss,predicted_actions_list = eval_1_video(agent, video_file, json_file, infer_demo_num, n_frames)
# load predicted actions and recorded actions
predicted_actions,recorded_actions = idm_prediction(agent, video_file,json_file, infer_demo_num, n_frames)
# construct labels
subtasks_labels = define_exclusive_classification_task(predicted_actions,recorded_actions,calculate_hot_bar = False)
for key in subtasks_labels:
if key not in dataset_labels:
dataset_labels[key] = {"pred_labels":[] , "rec_labels":[], "class_num":0}
dataset_labels[key]["pred_labels"].append(subtasks_labels[key]["pred_labels"])# array
dataset_labels[key]["rec_labels"].append(subtasks_labels[key]["rec_labels"]) # array
dataset_labels[key]["class_num"] = subtasks_labels[key]["class_num"]
camera_loss_list.append(camera_loss(predicted_actions,recorded_actions)["camera_bin_loss"])
dataset_results ={}
for key in dataset_labels:
pred_labels = np.stack(dataset_labels[key]["pred_labels"]).flatten() # [num_videos , num_frames] -> [video_num x frame_num]
rec_labels = np.stack(dataset_labels[key]["rec_labels"]).flatten() # [num_videos , num_frames] -> [video_num x frame_num]
dataset_results[key]=classification_metric(pred_labels, rec_labels, dataset_labels[key]["class_num"])
# import pdb;pdb.set_trace()
metric_mean_on_task = {}
metrics = ['precision_micro', 'recall_micro', 'f1_micro', 'precision_macro', 'recall_macro', 'f1_macro']
tasks = dataset_results.keys()
for key in metrics:
if key == "class_num":
continue
metric_mean_on_task[key] = np.mean([dataset_results[task][key] for task in tasks])
dataset_results["metric_mean_on_task"] = metric_mean_on_task
dataset_results["metric_mean_on_task"]["camera_loss"] = np.mean(camera_loss_list)
## change all keys into str
dataset_results = {str(k): v for k, v in dataset_results.items()}
print(dataset_results)
print("===========================================")
print(f"{output_file} IDM Metric: {metric_mean_on_task}")
with open(output_file, 'w') as f:
f.write(json.dumps(dataset_results,indent=4) + "\n")
def construct_classification_labels(idm_actions:dict[str, list[int]],action_name_keys: list[int],num_class:int) -> list[int]:
"""
convert original predicted actions to classification labels
"""
# construct a one-hot vector string to int label
vec2cls = {"0"*(num_class-1):0}
for i in range(num_class-1):
key = "0"*i + "1" + "0"*(num_class-2-i)
vec2cls[key] = i+1
# print(vec2cls)
vec2cls['1'*(num_class-1)] = 0 # do all equal not do
# vec2cls = {"00":0,"10":1,"01":2} # tested for class_num = 2
num_labels = idm_actions[action_name_keys[0]].size # assert same length: video_num x frame_per_video
# if not single in first dim, we should perform flattn
# construct one-hot vector
idm_action_string = [[str(int(i)) for i in idm_actions[action_name].flatten()] for action_name in action_name_keys]
try:
labels = [vec2cls["".join([idm_action_string[j][i] for j in range(num_class-1)])] for i in range(num_labels)]
except:
conflicts_num = sum([ i=='1' and j=='1' for i,j in zip(idm_action_string[0],idm_action_string[1])])
print(f"detect conflict prediction: {conflicts_num}")
return None
labels = np.array(labels)
return labels
def define_exclusive_classification_task(predicted_actions:dict,recorded_actions:dict,calculate_hot_bar = False) -> dict:
subtasks = {"multi_class":[("back","forward"),# 01,00,10,
("left","right"),
("sneak","sprint"),
],
"binary_class":["use","attack","jump","drop"]
}
if calculate_hot_bar:
subtasks["multi_class"]=[("hotbar.1","hotbar.2","hotbar.3","hotbar.4","hotbar.5","hotbar.6","hotbar.7","hotbar.8","hotbar.9")]
subtasks_labels = {}
for class_pair in subtasks["multi_class"]:
class_num = len(class_pair) + 1 # len = 2 has 00 01 10
# convert to strings
# convert to classification
pred_labels = construct_classification_labels(predicted_actions, class_pair, class_num)
rec_labels = construct_classification_labels(recorded_actions, class_pair, class_num)
if pred_labels is None or rec_labels is None:
print(f"detect conflict prediction: {pred_labels} and {rec_labels}")
continue
subtasks_labels[class_pair] = {"class_num":class_num,
"pred_labels":pred_labels,
"rec_labels":rec_labels
}
for binary_task in subtasks["binary_class"]:
pred_labels = predicted_actions[binary_task]
rec_labels = recorded_actions[binary_task]
subtasks_labels[binary_task] = {"class_num":2,
"pred_labels":pred_labels,
"rec_labels":rec_labels
}
return subtasks_labels
def classification_metric(pred_labels, rec_labels, class_num):
## compute macro and micro score for both tri classification and binary classification
## the difference between macro and micro precision and binary precision for binary task is :
## the binary precision only compute label with 1 ; but micro and marco compute 0, 1 and then average them
## to align with tri-classification we use average="macro" and average="micro"
precision_micro = precision_score(rec_labels, pred_labels, average="micro", zero_division=0)
recall_micro = recall_score(rec_labels, pred_labels, average="micro", zero_division=0)
f1_micro = f1_score(rec_labels, pred_labels, average="micro", zero_division=0)
precision_macro = precision_score(rec_labels, pred_labels, average="macro", zero_division=0)
recall_macro = recall_score(rec_labels, pred_labels, average="macro", zero_division=0)
f1_macro = f1_score(rec_labels, pred_labels, average="macro", zero_division=0)
return {
"precision_micro": precision_micro,
"recall_micro": recall_micro,
"f1_micro": f1_micro,
"precision_macro": precision_macro,
"recall_macro": recall_macro,
"f1_macro": f1_macro,
"class_num": class_num
}
def aggregate_actions(actions:list) -> dict:
return_dict = {}
for action in actions:
for key in action:
if key not in return_dict:
return_dict[key] = []
return_dict[key].append(action[key])
for key in return_dict:
return_dict[key] = np.array(return_dict[key]).reshape(-1)
return return_dict
def idm_prediction(agent, video_path,json_path, infer_demo_num, n_frames):
th.cuda.empty_cache()
full_json_data = load_action_jsonl(json_path)
json_data = full_json_data[infer_demo_num:infer_demo_num+n_frames]
recorded_actions = [json_action_to_env_action(i)[0] for i in json_data]
recorded_actions = aggregate_actions(recorded_actions)
frames = []
cap = cv2.VideoCapture(video_path)
for _ in range(n_frames):
ret, frame = cap.read()
if not ret:
print(f"[Error] loading frames in {video_path} returing {_}")
return None,None
# BGR -> RGB
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame)
frames = np.stack(frames)
predicted_actions = agent.predict_actions(frames)
for key in predicted_actions:
if key == "camera":
continue
predicted_actions[key] = np.array(predicted_actions[key]).reshape(-1)
return predicted_actions,recorded_actions
def camera_loss(predicted_actions,recorded_actions):
from lib.actions import CameraQuantizer
cam_quantizer = CameraQuantizer(
camera_binsize=2,
camera_maxval=10,
mu=10,
quantization_scheme="mu_law")
# import pdb;pdb.set_trace()
cam_pred_token=cam_quantizer.discretize(predicted_actions['camera'].reshape(-1))
cam_gt_token =cam_quantizer.discretize(np.array(recorded_actions['camera']))
camera_bin_loss = np.abs(cam_pred_token-cam_gt_token).mean()
return {
"camera_bin_loss":camera_bin_loss
}
if __name__ == "__main__":
parser = ArgumentParser("Evaluate IDM quality for MC-LVM ")
parser.add_argument("--weights", type=str, required=True, help="[IDM model config] Path to the '.weights' file to be loaded.")
parser.add_argument("--model", type=str, required=True, help="[IDM model config] Path to the '.model' file to be loaded.")
parser.add_argument("--jsonl-path", type=str, required=True, help="[Eval Config] Path to .jsonl contains actions.")
parser.add_argument("--video-path", type=str, required=True, help="[Eval Config] Path to a .mp4 file.")
parser.add_argument("--infer-demo-num", type=int, default=0, help="[Inference Config] Number of frames to skip before starting evaluation.")
parser.add_argument("--n-frames", type=int, default=32, help="[Inference Config] Number of frames to generation.")
parser.add_argument("--output-file", type=str, default="[Eval Config] output/action_loss.jsonl", help="[Eval Config] Path to save the action loss.")
args = parser.parse_args()
os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
evaluate_IDM_quality(args.model, args.weights,args.jsonl_path ,args.video_path, args.infer_demo_num,args.n_frames,args.output_file)
================================================
FILE: metrics/IDM/lib/__init__.py
================================================
================================================
FILE: metrics/IDM/lib/action_head.py
================================================
import logging
from typing import Any, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from gym3.types import DictType, Discrete, Real, TensorType, ValType
LOG0 = -100
def fan_in_linear(module: nn.Module, scale=1.0, bias=True):
"""Fan-in init"""
module.weight.data *= scale / module.weight.norm(dim=1, p=2, keepdim=True)
if bias:
module.bias.data *= 0
class ActionHead(nn.Module):
"""Abstract base class for action heads compatible with forc"""
def forward(self, input_data: torch.Tensor) -> Any:
"""
Just a forward pass through this head
:returns pd_params - parameters describing the probability distribution
"""
raise NotImplementedError
def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
"""Logartithm of probability of sampling `action_sample` from a probability described by `pd_params`"""
raise NotImplementedError
def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
"""Entropy of this distribution"""
raise NotImplementedError
def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> Any:
"""
Draw a sample from probability distribution given by those params
:param pd_params Parameters of a probability distribution
:param deterministic Whether to return a stochastic sample or deterministic mode of a distribution
"""
raise NotImplementedError
def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
"""KL divergence between two distribution described by these two params"""
raise NotImplementedError
class DiagGaussianActionHead(ActionHead):
"""
Action head where actions are normally distributed uncorrelated variables with specific means and variances.
Means are calculated directly from the network while standard deviations are a parameter of this module
"""
LOG2PI = np.log(2.0 * np.pi)
def __init__(self, input_dim: int, num_dimensions: int):
super().__init__()
self.input_dim = input_dim
self.num_dimensions = num_dimensions
self.linear_layer = nn.Linear(input_dim, num_dimensions)
self.log_std = nn.Parameter(torch.zeros(num_dimensions), requires_grad=True)
def reset_parameters(self):
init.orthogonal_(self.linear_layer.weight, gain=0.01)
init.constant_(self.linear_layer.bias, 0.0)
def forward(self, input_data: torch.Tensor, mask=None) -> torch.Tensor:
assert not mask, "Can not use a mask in a gaussian action head"
means = self.linear_layer(input_data)
# Unsqueeze many times to get to the same shape
logstd = self.log_std[(None,) * (len(means.shape) - 1)]
mean_view, logstd = torch.broadcast_tensors(means, logstd)
return torch.stack([mean_view, logstd], dim=-1)
def logprob(self, action_sample: torch.Tensor, pd_params: torch.Tensor) -> torch.Tensor:
"""Log-likelihood"""
means = pd_params[..., 0]
log_std = pd_params[..., 1]
std = torch.exp(log_std)
z_score = (action_sample - means) / std
return -(0.5 * ((z_score ** 2 + self.LOG2PI).sum(dim=-1)) + log_std.sum(dim=-1))
def entropy(self, pd_params: torch.Tensor) -> torch.Tensor:
"""
Categorical distribution entropy calculation - sum probs * log(probs).
In case of diagonal gaussian distribution - 1/2 log(2 pi e sigma^2)
"""
log_std = pd_params[..., 1]
return (log_std + 0.5 * (self.LOG2PI + 1)).sum(dim=-1)
def sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
means = pd_params[..., 0]
log_std = pd_params[..., 1]
if deterministic:
return means
else:
return torch.randn_like(means) * torch.exp(log_std) + means
def kl_divergence(self, params_q: torch.Tensor, params_p: torch.Tensor) -> torch.Tensor:
"""
Categorical distribution KL divergence calculation
KL(Q || P) = sum Q_i log (Q_i / P_i)
Formula is:
log(sigma_p) - log(sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2))/(2 * sigma_p^2)
"""
means_q = params_q[..., 0]
log_std_q = params_q[..., 1]
means_p = params_p[..., 0]
log_std_p = params_p[..., 1]
std_q = torch.exp(log_std_q)
std_p = torch.exp(log_std_p)
kl_div = log_std_p - log_std_q + (std_q ** 2 + (means_q - means_p) ** 2) / (2.0 * std_p ** 2) - 0.5
return kl_div.sum(dim=-1, keepdim=True)
class CategoricalActionHead(ActionHead):
"""Action head with categorical actions"""
def __init__(
self, input_dim: int, shape: Tuple[int], num_actions: int, builtin_linear_layer: bool = True, temperature: float = 1.0
):
super().__init__()
self.input_dim = input_dim
self.num_actions = num_actions
self.output_shape = shape + (num_actions,)
self.temperature = temperature
if builtin_linear_layer:
self.linear_layer = nn.Linear(input_dim, np.prod(self.output_shape))
else:
assert (
input_dim == num_actions
), f"If input_dim ({input_dim}) != num_actions ({num_actions}), you need a linear layer to convert them."
self.linear_layer = None
def reset_parameters(self):
if self.linear_layer is not None:
init.orthogonal_(self.linear_layer.weight, gain=0.01)
init.constant_(self.linear_layer.bias, 0.0)
finit.fan_in_linear(self.linear_layer, scale=0.01)
def forward(self, input_data: torch.Tensor, mask=None) -> Any:
if self.linear_layer is not None:
flat_out = self.linear_layer(input_data)
else:
flat_out = input_data
shaped_out = flat_out.reshape(flat_out.shape[:-1] + self.output_shape)
shaped_out /= self.temperature
if mask is not None:
shaped_out[~mask] = LOG0
# Convert to float32 to avoid RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Half'
return F.log_softmax(shaped_out.float(), dim=-1)
def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
value = actions.long().unsqueeze(-1)
value, log_pmf = torch.broadcast_tensors(value, logits)
value = value[..., :1]
result = log_pmf.gather(-1, value).squeeze(-1)
# result is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it.
for _ in self.output_shape[:-1]:
result = result.sum(dim=-1)
return result
def entropy(self, logits: torch.Tensor) -> torch.Tensor:
"""Categorical distribution entropy calculation - sum probs * log(probs)"""
probs = torch.exp(logits)
entropy = -torch.sum(probs * logits, dim=-1)
# entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it.
for _ in self.output_shape[:-1]:
entropy = entropy.sum(dim=-1)
return entropy
def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
if deterministic:
return torch.argmax(logits, dim=-1)
else:
# Gumbel-Softmax trick.
u = torch.rand_like(logits)
# In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0
# Then the log(-log(1.0)) will give -inf when it should give +inf
# This is a silly hack to get around that.
# This hack does not skew the probability distribution, because this event can't possibly win the argmax.
u[u == 1.0] = 0.999
return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1)
def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
"""
Categorical distribution KL divergence calculation
KL(Q || P) = sum Q_i log (Q_i / P_i)
When talking about logits this is:
sum exp(Q_i) * (Q_i - P_i)
"""
kl = (torch.exp(logits_q) * (logits_q - logits_p)).sum(-1, keepdim=True)
# kl is per-entry, still of size self.output_shape; we need to reduce of the rest of it.
for _ in self.output_shape[:-1]:
kl = kl.sum(dim=-2) # dim=-2 because we use keepdim=True above.
return kl
class DictActionHead(nn.ModuleDict):
"""Action head with multiple sub-actions"""
def reset_parameters(self):
for subhead in self.values():
subhead.reset_parameters()
def forward(self, input_data: torch.Tensor, **kwargs) -> Any:
"""
:param kwargs: each kwarg should be a dict with keys corresponding to self.keys()
e.g. if this ModuleDict has submodules keyed by 'A', 'B', and 'C', we could call:
forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7}}
Then children will be called with:
A: forward(input_data, foo=True, bar=7)
B: forward(input_data)
C: forward(input_Data, foo=False)
"""
result = {}
for head_name, subhead in self.items():
head_kwargs = {
kwarg_name: kwarg[head_name]
for kwarg_name, kwarg in kwargs.items()
if kwarg is not None and head_name in kwarg
}
result[head_name] = subhead(input_data, **head_kwargs)
return result
def logprob(self, actions: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
return sum(subhead.logprob(actions[k], logits[k]) for k, subhead in self.items())
def sample(self, logits: torch.Tensor, deterministic: bool = False) -> Any:
return {k: subhead.sample(logits[k], deterministic) for k, subhead in self.items()}
def entropy(self, logits: torch.Tensor) -> torch.Tensor:
return sum(subhead.entropy(logits[k]) for k, subhead in self.items())
def kl_divergence(self, logits_q: torch.Tensor, logits_p: torch.Tensor) -> torch.Tensor:
return sum(subhead.kl_divergence(logits_q[k], logits_p[k]) for k, subhead in self.items())
def make_action_head(ac_space: ValType, pi_out_size: int, temperature: float = 1.0):
"""Helper function to create an action head corresponding to the environment action space"""
if isinstance(ac_space, TensorType):
if isinstance(ac_space.eltype, Discrete):
return CategoricalActionHead(pi_out_size, ac_space.shape, ac_space.eltype.n, temperature=temperature)
elif isinstance(ac_space.eltype, Real):
if temperature != 1.0:
logging.warning("Non-1 temperature not implemented for DiagGaussianActionHead.")
assert len(ac_space.shape) == 1, "Nontrivial shapes not yet implemented."
return DiagGaussianActionHead(pi_out_size, ac_space.shape[0])
elif isinstance(ac_space, DictType):
return DictActionHead({k: make_action_head(v, pi_out_size, temperature) for k, v in ac_space.items()})
raise NotImplementedError(f"Action space of type {type(ac_space)} is not supported")
================================================
FILE: metrics/IDM/lib/action_mapping.py
================================================
import abc
import itertools
from collections import OrderedDict
from typing import Dict, List
import numpy as np
from gym3.types import DictType, Discrete, TensorType
from lib.actions import Buttons
class ActionMapping(abc.ABC):
"""Class that maps between the standard MC factored action space and a new one you define!
:param n_camera_bins: Need to specify this to define the original ac space for stats code
"""
# This is the default buttons groups, it can be changed for your action space
BUTTONS_GROUPS = OrderedDict(
hotbar=["none"] + [f"hotbar.{i}" for i in range(1, 10)],
fore_back=["none", "forward", "back"],
left_right=["none", "left", "right"],
sprint_sneak=["none", "sprint", "sneak"],
use=["none", "use"],
drop=["none", "drop"],
attack=["none", "attack"],
jump=["none", "jump"],
)
def __init__(self, n_camera_bins: int = 11):
assert n_camera_bins % 2 == 1, "n_camera_bins should be odd"
self.n_camera_bins = n_camera_bins
self.camera_null_bin = n_camera_bins // 2
self.stats_ac_space = DictType(
**{
"buttons": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)),
"camera": TensorType(shape=(2,), eltype=Discrete(n_camera_bins)),
}
)
@abc.abstractmethod
def from_factored(self, ac: Dict) -> Dict:
"""Converts a factored action (ac) to the new space
:param ac: Dictionary of actions that must have a batch dimension
"""
pass
@abc.abstractmethod
def to_factored(self, ac: Dict) -> Dict:
"""Converts an action in the new space (ac) to the factored action space.
:param ac: Dictionary of actions that must have a batch dimension
"""
pass
@abc.abstractmethod
def get_action_space_update(self):
"""Return a magym (gym3) action space. This will be used to update the env action space."""
pass
@abc.abstractmethod
def get_zero_action(self):
"""Return the zero or null action for this action space"""
pass
def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_group: List[str]) -> List[str]:
"""For a mutually exclusive group of buttons in button_group, find which option
in the group was chosen. Assumes that each button group has the option of 'none'
meaning that no button in the group was pressed.
:param ac_buttons: button actions from the factored action space. Should dims [B, len(Buttons.ALL)]
:param button_group: List of buttons in a mutually exclusive group. Each item in the
list should appear in Buttons.ALL except for the special case 'none' which means
no button in the group was pressed. e.g. ['none', 'forward', 'back']. For now
'none' must be the first element of button_group
Returns a list of length B, where each element is an item from button_group.
"""
assert ac_buttons.shape[1] == len(
Buttons.ALL
), f"There should be {len(Buttons.ALL)} buttons in the factored buttons space"
assert button_group[0] == "none", "This function only works if 'none' is in button_group"
# Actions in ac_buttons with order according to button_group
group_indices = [Buttons.ALL.index(b) for b in button_group if b != "none"]
ac_choices = ac_buttons[:, group_indices]
# Special cases for forward/back, left/right where mutual press means do neither
if "forward" in button_group and "back" in button_group:
ac_choices[np.all(ac_choices, axis=-1)] = 0
if "left" in button_group and "right" in button_group:
ac_choices[np.all(ac_choices, axis=-1)] = 0
ac_non_zero = np.where(ac_choices)
ac_choice = ["none" for _ in range(ac_buttons.shape[0])]
# Iterate over the non-zero indices so that if two buttons in a group were pressed at the same time
# we give priority to the button later in the group. E.g. if hotbar.1 and hotbar.2 are pressed during the same
# timestep, hotbar.2 is marked as pressed
for index, action in zip(ac_non_zero[0], ac_non_zero[1]):
ac_choice[index] = button_group[action + 1] # the zero'th index will mean no button pressed
return ac_choice
class IDMActionMapping(ActionMapping):
"""For IDM, but essentially this is just an identity mapping"""
def from_factored(self, ac: Dict) -> Dict:
return ac
def to_factored(self, ac: Dict) -> Dict:
return ac
def get_action_space_update(self):
"""Return a magym (gym3) action space. This will be used to update the env action space."""
return {
"buttons": TensorType(shape=(len(Buttons.ALL),), eltype=Discrete(2)),
"camera": TensorType(shape=(2,), eltype=Discrete(self.n_camera_bins)),
}
def get_zero_action(self):
raise NotImplementedError()
class CameraHierarchicalMapping(ActionMapping):
"""Buttons are joint as in ButtonsJointMapping, but now a camera on/off meta action is added into this joint space.
When this meta action is triggered, the separate camera head chooses a camera action which is also now a joint space.
:param n_camera_bins: number of camera bins in the factored space
"""
# Add camera meta action to BUTTONS_GROUPS
BUTTONS_GROUPS = ActionMapping.BUTTONS_GROUPS.copy()
BUTTONS_GROUPS["camera"] = ["none", "camera"]
BUTTONS_COMBINATIONS = list(itertools.product(*BUTTONS_GROUPS.values())) + ["inventory"]
BUTTONS_COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(BUTTONS_COMBINATIONS)}
BUTTONS_IDX_TO_COMBINATION = {i: comb for i, comb in enumerate(BUTTONS_COMBINATIONS)}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.camera_groups = OrderedDict(
camera_x=[f"camera_x{i}" for i in range(self.n_camera_bins)],
camera_y=[f"camera_y{i}" for i in range(self.n_camera_bins)],
)
self.camera_combinations = list(itertools.product(*self.camera_groups.values()))
self.camera_combination_to_idx = {comb: i for i, comb in enumerate(self.camera_combinations)}
self.camera_idx_to_combination = {i: comb for i, comb in enumerate(self.camera_combinations)}
self.camera_null_idx = self.camera_combination_to_idx[
(f"camera_x{self.camera_null_bin}", f"camera_y{self.camera_null_bin}")
]
self._null_action = {
"buttons": self.BUTTONS_COMBINATION_TO_IDX[tuple("none" for _ in range(len(self.BUTTONS_GROUPS)))]
}
self._precompute_to_factored()
def _precompute_to_factored(self):
"""Precompute the joint action -> factored action matrix."""
button_dim = self.stats_ac_space["buttons"].size
self.BUTTON_IDX_TO_FACTORED = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION), button_dim), dtype=int)
self.BUTTON_IDX_TO_CAMERA_META_OFF = np.zeros((len(self.BUTTONS_IDX_TO_COMBINATION)), dtype=bool)
self.CAMERA_IDX_TO_FACTORED = np.zeros((len(self.camera_idx_to_combination), 2), dtype=int)
# Pre compute Buttons
for jnt_ac, button_comb in self.BUTTONS_IDX_TO_COMBINATION.items():
new_button_ac = np.zeros(len(Buttons.ALL), dtype="i")
if button_comb == "inventory":
new_button_ac[Buttons.ALL.index("inventory")] = 1
else:
for group_choice in button_comb[:-1]: # Last one is camera
if group_choice != "none":
new_button_ac[Buttons.ALL.index(group_choice)] = 1
if button_comb[-1] != "camera": # This means camera meta action is off
self.BUTTON_IDX_TO_CAMERA_META_OFF[jnt_ac] = True
self.BUTTON_IDX_TO_FACTORED[jnt_ac] = new_button_ac
# Pre compute camera
for jnt_ac, camera_comb in self.camera_idx_to_combination.items():
new_camera_ac = np.ones((2), dtype="i") * self.camera_null_bin
new_camera_ac[0] = self.camera_groups["camera_x"].index(camera_comb[0])
new_camera_ac[1] = self.camera_groups["camera_y"].index(camera_comb[1])
self.CAMERA_IDX_TO_FACTORED[jnt_ac] = new_camera_ac
def from_factored(self, ac: Dict) -> Dict:
"""Converts a factored action (ac) to the new space. Assumes ac has a batch dim"""
assert ac["camera"].ndim == 2, f"bad camera label, {ac['camera']}"
assert ac["buttons"].ndim == 2, f"bad buttons label, {ac['buttons']}"
# Get button choices for everything but camera
choices_by_group = OrderedDict(
(k, self.factored_buttons_to_groups(ac["buttons"], v)) for k, v in self.BUTTONS_GROUPS.items() if k != "camera"
)
# Set camera "on off" action based on whether non-null camera action was given
camera_is_null = np.all(ac["camera"] == self.camera_null_bin, axis=1)
choices_by_group["camera"] = ["none" if is_null else "camera" for is_null in camera_is_null]
new_button_ac = []
new_camera_ac = []
for i in range(ac["buttons"].shape[0]):
# Buttons
key = tuple([v[i] for v in choices_by_group.values()])
if ac["buttons"][i, Buttons.ALL.index("inventory")] == 1:
key = "inventory"
new_button_ac.append(self.BUTTONS_COMBINATION_TO_IDX[key])
# Camera -- inventory is also exclusive with camera
if key == "inventory":
key = (
f"camera_x{self.camera_null_bin}",
f"camera_y{self.camera_null_bin}",
)
else:
key = (f"camera_x{ac['camera'][i][0]}", f"camera_y{ac['camera'][i][1]}")
new_camera_ac.append(self.camera_combination_to_idx[key])
return dict(
buttons=np.array(new_button_ac)[:, None],
camera=np.array(new_camera_ac)[:, None],
)
def to_factored(self, ac: Dict) -> Dict:
"""Converts an action in the new space (ac) to the factored action space. Assumes ac has a batch dim"""
assert ac["camera"].shape[-1] == 1
assert ac["buttons"].shape[-1] == 1
new_button_ac = self.BUTTON_IDX_TO_FACTORED[np.squeeze(ac["buttons"], -1)]
camera_off = self.BUTTON_IDX_TO_CAMERA_META_OFF[np.squeeze(ac["buttons"], -1)]
new_camera_ac = self.CAMERA_IDX_TO_FACTORED[np.squeeze(ac["camera"], -1)]
new_camera_ac[camera_off] = self.camera_null_bin
return dict(buttons=new_button_ac, camera=new_camera_ac)
def get_action_space_update(self):
return {
"camera": TensorType(shape=(1,), eltype=Discrete(len(self.camera_combinations))),
"buttons": TensorType(shape=(1,), eltype=Discrete(len(self.BUTTONS_COMBINATIONS))),
}
def get_zero_action(self):
return self._null_action
================================================
FILE: metrics/IDM/lib/actions.py
================================================
import attr
# import minerl.herobraine.hero.mc as mc
import numpy as np
from lib.minecraft_util import store_args
class Buttons:
ATTACK = "attack"
BACK = "back"
FORWARD = "forward"
JUMP = "jump"
LEFT = "left"
RIGHT = "right"
SNEAK = "sneak"
SPRINT = "sprint"
USE = "use"
DROP = "drop"
INVENTORY = "inventory"
ALL = [
ATTACK,
BACK,
FORWARD,
JUMP,
LEFT,
RIGHT,
SNEAK,
SPRINT,
USE,
DROP,
INVENTORY,
] + [f"hotbar.{i}" for i in range(1, 10)]
class SyntheticButtons:
# Composite / scripted actions
CHANNEL_ATTACK = "channel-attack"
ALL = [CHANNEL_ATTACK]
class QuantizationScheme:
LINEAR = "linear"
MU_LAW = "mu_law"
@attr.s(auto_attribs=True)
class CameraQuantizer:
"""
A camera quantizer that discretizes and undiscretizes a continuous camera input with y (pitch) and x (yaw) components.
Parameters:
- camera_binsize: The size of the bins used for quantization. In case of mu-law quantization, it corresponds to the average binsize.
- camera_maxval: The maximum value of the camera action.
- quantization_scheme: The quantization scheme to use. Currently, two quantization schemes are supported:
- Linear quantization (default): Camera actions are split uniformly into discrete bins
- Mu-law quantization: Transforms the camera action using mu-law encoding (https://en.wikipedia.org/wiki/%CE%9C-law_algorithm)
followed by the same quantization scheme used by the linear scheme.
- mu: Mu is the parameter that defines the curvature of the mu-law encoding. Higher values of
mu will result in a sharper transition near zero. Below are some reference values listed
for choosing mu given a constant maxval and a desired max_precision value.
maxval = 10 | max_precision = 0.5 | μ ≈ 2.93826
maxval = 10 | max_precision = 0.4 | μ ≈ 4.80939
maxval = 10 | max_precision = 0.25 | μ ≈ 11.4887
maxval = 20 | max_precision = 0.5 | μ ≈ 2.7
maxval = 20 | max_precision = 0.4 | μ ≈ 4.39768
maxval = 20 | max_precision = 0.25 | μ ≈ 10.3194
maxval = 40 | max_precision = 0.5 | μ ≈ 2.60780
maxval = 40 | max_precision = 0.4 | μ ≈ 4.21554
maxval = 40 | max_precision = 0.25 | μ ≈ 9.81152
"""
camera_maxval: int
camera_binsize: int
quantization_scheme: str = attr.ib(
default=QuantizationScheme.LINEAR,
validator=attr.validators.in_([QuantizationScheme.LINEAR, QuantizationScheme.MU_LAW]),
)
mu: float = attr.ib(default=5)
def discretize(self, xy):
xy = np.clip(xy, -self.camera_maxval, self.camera_maxval)
if self.quantization_scheme == QuantizationScheme.MU_LAW:
xy = xy / self.camera_maxval
v_encode = np.sign(xy) * (np.log(1.0 + self.mu * np.abs(xy)) / np.log(1.0 + self.mu))
v_encode *= self.camera_maxval
xy = v_encode
# Quantize using linear scheme
return np.round((xy + self.camera_maxval) / self.camera_binsize).astype(np.int64)
def undiscretize(self, xy):
xy = xy * self.camera_binsize - self.camera_maxval
if self.quantization_scheme == QuantizationScheme.MU_LAW:
xy = xy / self.camera_maxval
v_decode = np.sign(xy) * (1.0 / self.mu) * ((1.0 + self.mu) ** np.abs(xy) - 1.0)
v_decode *= self.camera_maxval
xy = v_decode
return xy
class ActionTransformer:
"""Transforms actions between internal array and minerl env format."""
@store_args
def __init__(
self,
camera_maxval=10,
camera_binsize=2,
camera_quantization_scheme="linear",
camera_mu=5,
):
self.quantizer = CameraQuantizer(
camera_maxval=camera_maxval,
camera_binsize=camera_binsize,
quantization_scheme=camera_quantization_scheme,
mu=camera_mu,
)
def camera_zero_bin(self):
return self.camera_maxval // self.camera_binsize
def discretize_camera(self, xy):
return self.quantizer.discretize(xy)
def undiscretize_camera(self, pq):
return self.quantizer.undiscretize(pq)
def item_embed_id_to_name(self, item_id):
return mc.MINERL_ITEM_MAP[item_id]
def dict_to_numpy(self, acs):
"""
Env format to policy output format.
"""
act = {
"buttons": np.stack([acs.get(k, 0) for k in Buttons.ALL], axis=-1),
"camera": self.discretize_camera(acs["camera"]),
}
if not self.human_spaces:
act.update(
{
"synthetic_buttons": np.stack([acs[k] for k in SyntheticButtons.ALL], axis=-1),
"place": self.item_embed_name_to_id(acs["place"]),
"equip": self.item_embed_name_to_id(acs["equip"]),
"craft": self.item_embed_name_to_id(acs["craft"]),
}
)
return act
def numpy_to_dict(self, acs):
"""
Numpy policy output to env-compatible format.
"""
assert acs["buttons"].shape[-1] == len(
Buttons.ALL
), f"Mismatched actions: {acs}; expected {len(Buttons.ALL)}:\n( {Buttons.ALL})"
out = {name: acs["buttons"][..., i] for (i, name) in enumerate(Buttons.ALL)}
out["camera"] = self.undiscretize_camera(acs["camera"])
return out
def policy2env(self, acs):
acs = self.numpy_to_dict(acs)
return acs
def env2policy(self, acs):
nbatch = acs["camera"].shape[0]
dummy = np.zeros((nbatch,))
out = {
"camera": self.discretize_camera(acs["camera"]),
"buttons": np.stack([acs.get(k, dummy) for k in Buttons.ALL], axis=-1),
}
return out
================================================
FILE: metrics/IDM/lib/impala_cnn.py
================================================
import math
from copy import deepcopy
from typing import Dict, List, Optional
from torch import nn
from torch.nn import functional as F
from lib import misc
from lib import torch_util as tu
from lib.util import FanInInitReLULayer
class CnnBasicBlock(nn.Module):
"""
Residual basic block, as in ImpalaCNN. Preserves channel number and shape
:param inchan: number of input channels
:param init_scale: weight init scale multiplier
"""
def __init__(
self,
inchan: int,
init_scale: float = 1,
log_scope="",
init_norm_kwargs: Dict = {},
**kwargs,
):
super().__init__()
self.inchan = inchan
s = math.sqrt(init_scale)
self.conv0 = FanInInitReLULayer(
self.inchan,
self.inchan,
kernel_size=3,
padding=1,
init_scale=s,
log_scope=f"{log_scope}/conv0",
**init_norm_kwargs,
)
self.conv1 = FanInInitReLULayer(
self.inchan,
self.inchan,
kernel_size=3,
padding=1,
init_scale=s,
log_scope=f"{log_scope}/conv1",
**init_norm_kwargs,
)
def forward(self, x):
x = x + self.conv1(self.conv0(x))
return x
class CnnDownStack(nn.Module):
"""
Downsampling stack from Impala CNN.
:param inchan: number of input channels
:param nblock: number of residual blocks after downsampling
:param outchan: number of output channels
:param init_scale: weight init scale multiplier
:param pool: if true, downsample with max pool
:param post_pool_groups: if not None, normalize with group norm with this many groups
:param kwargs: remaining kwargs are passed into the blocks and layers
"""
name = "Impala_CnnDownStack"
def __init__(
self,
inchan: int,
nblock: int,
outchan: int,
init_scale: float = 1,
pool: bool = True,
post_pool_groups: Optional[int] = None,
log_scope: str = "",
init_norm_kwargs: Dict = {},
first_conv_norm=False,
**kwargs,
):
super().__init__()
self.inchan = inchan
self.outchan = outchan
self.pool = pool
first_conv_init_kwargs = deepcopy(init_norm_kwargs)
if not first_conv_norm:
first_conv_init_kwargs["group_norm_groups"] = None
first_conv_init_kwargs["batch_norm"] = False
self.firstconv = FanInInitReLULayer(
inchan,
outchan,
kernel_size=3,
padding=1,
log_scope=f"{log_scope}/firstconv",
**first_conv_init_kwargs,
)
self.post_pool_groups = post_pool_groups
if post_pool_groups is not None:
self.n = nn.GroupNorm(post_pool_groups, outchan)
self.blocks = nn.ModuleList(
[
CnnBasicBlock(
outchan,
init_scale=init_scale / math.sqrt(nblock),
log_scope=f"{log_scope}/block{i}",
init_norm_kwargs=init_norm_kwargs,
**kwargs,
)
for i in range(nblock)
]
)
def forward(self, x):
x = self.firstconv(x)
if self.pool:
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
if self.post_pool_groups is not None:
x = self.n(x)
x = tu.sequential(self.blocks, x, diag_name=self.name)
return x
def output_shape(self, inshape):
c, h, w = inshape
assert c == self.inchan
if self.pool:
return (self.outchan, (h + 1) // 2, (w + 1) // 2)
else:
return (self.outchan, h, w)
class ImpalaCNN(nn.Module):
"""
:param inshape: input image shape (height, width, channels)
:param chans: number of residual downsample stacks. Each element is the number of
filters per convolution in the stack
:param outsize: output hidden size
:param nblock: number of residual blocks per stack. Each block has 2 convs and a residual
:param init_norm_kwargs: arguments to be passed to convolutional layers. Options can be found
in ypt.model.util:FanInInitReLULayer
:param dense_init_norm_kwargs: arguments to be passed to convolutional layers. Options can be found
in ypt.model.util:FanInInitReLULayer
:param kwargs: remaining kwargs are passed into the CnnDownStacks
"""
name = "ImpalaCNN"
def __init__(
self,
inshape: List[int],
chans: List[int],
outsize: int,
nblock: int,
init_norm_kwargs: Dict = {},
dense_init_norm_kwargs: Dict = {},
first_conv_norm=False,
**kwargs,
):
super().__init__()
h, w, c = inshape
curshape = (c, h, w)
self.stacks = nn.ModuleList()
for i, outchan in enumerate(chans):
stack = CnnDownStack(
curshape[0],
nblock=nblock,
outchan=outchan,
init_scale=math.sqrt(len(chans)),
log_scope=f"downstack{i}",
init_norm_kwargs=init_norm_kwargs,
first_conv_norm=first_conv_norm if i == 0 else True,
**kwargs,
)
self.stacks.append(stack)
curshape = stack.output_shape(curshape)
self.dense = FanInInitReLULayer(
misc.intprod(curshape),
outsize,
layer_type="linear",
log_scope="imapala_final_dense",
init_scale=1.4,
**dense_init_norm_kwargs,
)
self.outsize = outsize
def forward(self, x):
b, t = x.shape[:-3]
x = x.reshape(b * t, *x.shape[-3:])
x = misc.transpose(x, "bhwc", "bchw")
x = tu.sequential(self.stacks, x, diag_name=self.name)
x = x.reshape(b, t, *x.shape[1:])
x = tu.flatten_image(x)
x = self.dense(x)
return x
================================================
FILE: metrics/IDM/lib/masked_attention.py
================================================
import functools
import torch as th
from torch import nn
import lib.xf as xf
from lib.minecraft_util import store_args
from lib.tree_util import tree_map
@functools.lru_cache()
def get_band_diagonal_mask(t: int, T: int, maxlen: int, batchsize: int, device: th.device) -> th.Tensor:
"""Returns a band diagonal mask which is causal (upper triangle is masked)
and such that any frame can only view up to maxlen total past frames
including the current frame.
Example Masks: Here 0 means that frame is masked and we mask it by adding a huge number to the attention logits (see orc.xf)
t = 3, T = 3, maxlen = 3
T
t 1 0 0 | mask out T > t
1 1 0 |
1 1 1 |
t = 3, T = 6, maxlen = 3
t 0 1 1 1 0 0 | mask out T > t
0 0 1 1 1 0 |
0 0 0 1 1 1 |
Args:
t: number of rows (presumably number of frames recieving gradient)
T: number of cols (presumably t + past context that isn't being gradient updated)
maxlen: maximum number of frames (including current frame) any frame can attend to
batchsize: number of masks to return
device: torch device to place mask on
Returns:
Boolean mask of shape (batchsize, t, T)
"""
m = th.ones(t, T, dtype=bool)
m.tril_(T - t) # Mask out upper triangle
if maxlen is not None and maxlen < T: # Mask out lower triangle
m.triu_(T - t - maxlen + 1)
m_btT = m[None].repeat_interleave(batchsize, dim=0)
m_btT = m_btT.to(device=device)
return m_btT
def get_mask(first_b11: th.Tensor, state_mask: th.Tensor, t: int, T: int, maxlen: int, heads: int, device) -> th.Tensor:
"""Returns a band diagonal mask that respects masking past states (columns 0:T-t inclusive)
if first_b11 is True. See get_band_diagonal_mask for how the base mask is computed.
This function takes that mask and first zeros out any past context if first_b11 is True.
Say our context is in chunks of length t (so here T = 4t). We see that in the second batch we recieved first=True
context t t t t
first F T F F
Now, given this the mask should mask out anything prior to T < t; however since we don't have access to the past first_b11's
we need to keep a state of the mask at those past timesteps. This is what state_mask is.
In particular state_mask is a [b, t, T - t] mask matrix that contains the mask for the past T - t frames.
Args: (See get_band_diagonal_mask for remaining args)
first_b11: boolean tensor with shape [batchsize, 1, 1] indicating if the first timestep for each batch element had first=True
state_mask: mask tensor of shape [b, t, T - t]
t: number of mask rows (presumably number of frames for which we take gradient)
T: number of mask columns (t + the number of past frames we keep in context)
maxlen: actual context length
heads: number of attention heads
device: torch device
Returns:
m_btT: Boolean mask of shape (batchsize * heads, t, T)
state_mask: updated state_mask
"""
b = first_b11.shape[0]
if state_mask is None:
state_mask = th.zeros((b, 1, T - t), dtype=bool, device=device)
m_btT = get_band_diagonal_mask(t, T, maxlen, b, device).clone() # Should be shape B, t, T
not_first = ~first_b11.to(device=device)
m_btT[:, :, :-t] &= not_first # Zero out anything in the past if first is true
m_btT[:, :, :-t] &= state_mask
m_bhtT = m_btT[:, None].repeat_interleave(heads, dim=1)
m_btT = m_bhtT.reshape((b * heads), t, T)
# Update state_mask such that it reflects the most recent first
state_mask = th.cat(
[
state_mask[:, :, t:] & not_first,
th.ones((b, 1, min(t, T - t)), dtype=bool, device=device),
],
dim=-1,
)
return m_btT, state_mask
class MaskedAttention(nn.Module):
"""
Transformer self-attention layer that removes frames from previous episodes from the hidden state under certain constraints.
The constraints are:
- The "first" flag can only be true for the first timestep of each batch. An assert will fire if other timesteps have first = True.
input_size: The dimension of the input (which also happens to be the size of the output)
memory_size: The number of frames to keep in the inner state. Note that when attending, we will be able to attend
to both the frames in the inner state (which presumably won't have gradients anymore) and the frames
in the batch. "mask" for some additional considerations on this.
heads: The number of attention heads to use. Note that we will split the input into this number of heads, so
input_size needs to be divisible by heads.
timesteps: number of timesteps with which we'll be taking gradient
mask: Can be "none" or "clipped_causal". "clipped_causal" is a normal causal mask but solves the following minor problem:
if you have a state of length 128 and a batch of 128 frames, then the first frame of your batch will be able to
attend to 128 previous frames, but the last one will be able to attend to 255 previous frames. In this example,
"clipped_causal" will make it so that the last frame can only attend to 128 previous frames, so that there is no
bias coming from the position in the batch. None simply allows you to attend to any frame in the state + batch,
which means you can also attend to future frames.
"""
@store_args
def __init__(
self,
input_size,
memory_size: int,
heads: int,
timesteps: int,
mask: str = "clipped_causal",
init_scale=1,
norm="none",
log_scope="sa",
use_muP_factor=False,
):
super().__init__()
assert mask in {"none", "clipped_causal"}
assert memory_size >= 0
self.maxlen = memory_size - timesteps
if mask == "none":
mask = None
self.orc_attn = xf.All2All(heads, self.maxlen, mask=mask is not None)
self.orc_block = xf.SelfAttentionLayer(
input_size,
self.orc_attn,
scale=init_scale,
relattn=True,
cache_keep_len=self.maxlen,
norm=norm,
log_scope=log_scope,
use_muP_factor=use_muP_factor,
)
def initial_state(self, batchsize: int, device=None):
"""Return the initial state mask (None) and the initial state of the transformer (zerod out keys and queries)"""
state = self.orc_block.initial_state(batchsize, initial_T=self.maxlen)
state_mask = None
if device is not None:
state = tree_map(lambda x: x.to(device), state)
return state_mask, state
def forward(self, input_bte, first_bt, state):
"""Forward propagation of a single layer"""
state_mask, xf_state = state
t = first_bt.shape[1]
if self.mask == "clipped_causal":
new_mask, state_mask = get_mask(
first_b11=first_bt[:, [[0]]],
state_mask=state_mask,
t=t,
T=t + self.maxlen,
maxlen=self.maxlen,
heads=self.heads,
device=input_bte.device,
)
self.orc_block.attn.mask = new_mask
output, xf_state = self.orc_block(input_bte, xf_state)
return output, (state_mask, xf_state)
def get_log_keys(self):
# These are logged in xf.SelfAttentionLayer
return [f"activation_{stat}/{self.log_scope}/{k}" for k in ["K", "Q", "V", "A", "Aproj"] for stat in ["mean", "std"]]
================================================
FILE: metrics/IDM/lib/minecraft_util.py
================================================
import functools
import inspect
from typing import Optional, Tuple
import numpy as np
import torch
from lib.action_head import (CategoricalActionHead, DiagGaussianActionHead,
DictActionHead)
def store_args(method):
"""Stores provided method args as instance attributes."""
argspec = inspect.getfullargspec(method)
defaults = {}
if argspec.defaults is not None:
defaults = dict(zip(argspec.args[-len(argspec.defaults) :], argspec.defaults))
if argspec.kwonlydefaults is not None:
defaults.update(argspec.kwonlydefaults)
arg_names = argspec.args[1:]
@functools.wraps(method)
def wrapper(*positional_args, **keyword_args):
self = positional_args[0]
# Get default arg values
args = defaults.copy()
# Add provided arg values
for name, value in zip(arg_names, positional_args[1:]):
args[name] = value
args.update(keyword_args)
self.__dict__.update(args)
return method(*positional_args, **keyword_args)
return wrapper
def get_norm_entropy_from_cat_head(module, name, masks, logits):
# Note that the mask has already been applied to the logits at this point
entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
if name in masks:
n = torch.sum(masks[name], dim=-1, dtype=torch.float)
norm_entropy = entropy / torch.log(n)
# When the mask only allows one option the normalized entropy makes no sense
# as it is basically both maximal (the distribution is as uniform as it can be)
# and minimal (there is no variance at all).
# A such, we ignore them for purpose of calculating entropy.
zero = torch.zeros_like(norm_entropy)
norm_entropy = torch.where(n.eq(1.0), zero, norm_entropy)
count = n.not_equal(1.0).int()
else:
n = torch.tensor(logits.shape[-1], dtype=torch.float)
norm_entropy = entropy / torch.log(n)
count = torch.ones_like(norm_entropy, dtype=torch.int)
# entropy is per-entry, still of size self.output_shape[:-1]; we need to reduce of the rest of it.
for _ in module.output_shape[:-1]:
norm_entropy = norm_entropy.sum(dim=-1)
count = count.sum(dim=-1)
return norm_entropy, count
def get_norm_cat_entropy(module, masks, logits, template) -> Tuple[torch.Tensor, torch.Tensor]:
entropy_sum = torch.zeros_like(template, dtype=torch.float)
counts = torch.zeros_like(template, dtype=torch.int)
for k, subhead in module.items():
if isinstance(subhead, DictActionHead):
entropy, count = get_norm_cat_entropy(subhead, masks, logits[k], template)
elif isinstance(subhead, CategoricalActionHead):
entropy, count = get_norm_entropy_from_cat_head(subhead, k, masks, logits[k])
else:
continue
entropy_sum += entropy
counts += count
return entropy_sum, counts
def get_diag_guassian_entropy(module, logits, template) -> Optional[torch.Tensor]:
entropy_sum = torch.zeros_like(template, dtype=torch.float)
count = torch.zeros(1, device=template.device, dtype=torch.int)
for k, subhead in module.items():
if isinstance(subhead, DictActionHead):
entropy_sum += get_diag_guassian_entropy(subhead, logits[k], template)
elif isinstance(subhead, DiagGaussianActionHead):
entropy_sum += module.entropy(logits)
else:
continue
count += 1
return entropy_sum / count
================================================
FILE: metrics/IDM/lib/misc.py
================================================
import numpy as np
import torch as th
def intprod(xs):
"""
Product of a sequence of integers
"""
out = 1
for x in xs:
out *= x
return out
def safezip(*args):
"""
Check that lengths of sequences are the same, then zip them
"""
args = [list(a) for a in args]
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, f"length mismatch: {list(map(len, args))}"
return list(zip(*args))
def transpose(x, before, after):
"""
Usage: x_bca = transpose(x_abc, 'abc', 'bca')
"""
assert sorted(before) == sorted(after), f"cannot transpose {before} to {after}"
assert x.ndim == len(
before
), f"before spec '{before}' has length {len(before)} but x has {x.ndim} dimensions: {tuple(x.shape)}"
return x.permute(tuple(before.index(i) for i in after))
def transpose_undo(x, before, after, *, undo=None):
"""
Usage:
x_bca, undo = transpose_undo(x_abc, 'abc', 'bca')
x_bca = fully_connected_layer(x_bca)
x_abc = undo(x_bca)
"""
return (
transpose(x, before, after),
compose_undo(undo, lambda x: transpose(x, before=after, after=before)),
)
def compose_undo(u1, u2):
assert u2 is not None
if u1 is None:
return u2
def u(x):
x = u2(x)
x = u1(x)
return x
return u
NO_BIND = "__nobind"
def _parse_reshape_str(s, kind):
assert kind in ("before", "after")
result = []
n_underscores = 0
for i, part in enumerate(s.split(",")):
part = part.strip()
if part == "?" and kind == "before":
result.append([f"__{i}"])
elif part == "_":
result.append([f"{NO_BIND}_{n_underscores}"])
n_underscores += 1
else:
result.append([term.strip() for term in part.split("*")])
return result
def _infer_part(part, concrete_dim, known, index, full_shape):
if type(part) is int:
return part
assert isinstance(part, list), part
lits = []
syms = []
for term in part:
if type(term) is int:
lits.append(term)
elif type(term) is str:
syms.append(term)
else:
raise TypeError(f"got {type(term)} but expected int or str")
int_part = 1
for x in lits:
int_part *= x
if len(syms) == 0:
return int_part
elif len(syms) == 1 and concrete_dim is not None:
assert concrete_dim % int_part == 0, f"{concrete_dim} % {int_part} != 0 (at index {index}, full shape is {full_shape})"
v = concrete_dim // int_part
if syms[0] in known:
assert (
known[syms[0]] == v
), f"known value for {syms[0]} is {known[syms[0]]} but found value {v} at index {index} (full shape is {full_shape})"
else:
known[syms[0]] = v
return concrete_dim
else:
for i in range(len(syms)):
if syms[i] in known:
syms[i] = known[syms[i]]
else:
try:
syms[i] = int(syms[i])
except ValueError:
pass
return lits + syms
def _infer_step(args):
known, desc, shape = args
new_known = known.copy()
new_desc = desc.copy()
for i in range(len(desc)):
if shape is None:
concrete_dim = None
else:
concrete_dim = shape[i]
new_desc[i] = _infer_part(part=desc[i], concrete_dim=concrete_dim, known=new_known, index=i, full_shape=shape)
return new_known, new_desc, shape
def _infer(known, desc, shape):
if shape is not None:
assert len(desc) == len(shape), f"desc has length {len(desc)} but shape has length {len(shape)} (shape={shape})"
known, desc, shape = fixed_point(_infer_step, (known, desc, shape))
return desc, known
def fixed_point(f, x, eq=None):
if eq is None:
eq = lambda a, b: a == b
while True:
new_x = f(x)
if eq(x, new_x):
return x
else:
x = new_x
def _infer_question_mark(x, total_product):
try:
question_mark_index = x.index(["?"])
except ValueError:
return x
observed_product = 1
for i in range(len(x)):
if i != question_mark_index:
assert type(x[i]) is int, f"when there is a question mark, there can be no other unknown values (full list: {x})"
observed_product *= x[i]
assert (
observed_product and total_product % observed_product == 0
), f"{total_product} is not divisible by {observed_product}"
value = total_product // observed_product
x = x.copy()
x[question_mark_index] = value
return x
def _ground(x, known, infer_question_mark_with=None):
x, known = _infer(known=known, desc=x, shape=None)
if infer_question_mark_with:
x = _infer_question_mark(x, infer_question_mark_with)
for part in x:
assert type(part) is int, f"cannot infer value of {part}"
return x
def _handle_ellipsis(x, before, after):
ell = ["..."]
try:
i = before.index(ell)
l = len(x.shape) - len(before) + 1
ellipsis_value = x.shape[i : i + l]
ellipsis_value = list(ellipsis_value)
before = before[:i] + ellipsis_value + before[i + 1 :]
except ValueError:
pass
try:
i = after.index(ell)
after = after[:i] + ellipsis_value + after[i + 1 :]
except ValueError:
pass
except UnboundLocalError as e:
raise ValueError("there cannot be an ellipsis in 'after' unless there is an ellipsis in 'before'") from e
return before, after
def reshape_undo(inp, before, after, *, undo=None, known=None, **kwargs):
"""
Usage:
x_Bhwse, undo = reshape_undo(
x_bthwe,
'b, t, ..., stride*e',
'b*t, ..., stride, e',
stride=7
)
x_Bhwse = do_some_stuff(x_Bhwse)
x_bthwe = undo(x_Bhwse)
It's necessary to pass known values as keywords only
when they can't be inferred from the shape.
(Eg. in the above example we needed to pass
stride but not b, t, or e, since those can be determined from
inp.shape once stride is known.)
"""
if known:
known = {**kwargs, **known}
else:
known = kwargs
assert type(before) is type(after), f"{type(before)} != {type(after)}"
assert isinstance(inp, (th.Tensor, np.ndarray)), f"require tensor or ndarray but got {type(inp)}"
assert isinstance(before, (str, list)), f"require str or list but got {type(before)}"
if isinstance(before, str):
before = _parse_reshape_str(before, "before")
after = _parse_reshape_str(after, "after")
before, after = _handle_ellipsis(inp, before, after)
before_saved, after_saved = before, after
before, known = _infer(known=known, desc=before, shape=inp.shape)
before = _ground(before, known, product(inp.shape))
after = _ground(after, known, product(inp.shape))
known = {k: v for k, v in known.items() if not k.startswith(NO_BIND)}
assert tuple(inp.shape) == tuple(before), f"expected shape {before} but got shape {inp.shape}"
assert product(inp.shape) == product(
after
), f"cannot reshape {inp.shape} to {after} because the number of elements does not match"
return (
inp.reshape(after),
compose_undo(undo, lambda inp: reshape(inp, after_saved, before_saved, known=known)),
)
def reshape(*args, **kwargs):
"""
Please see the documenation for reshape_undo.
"""
x, _ = reshape_undo(*args, **kwargs)
return x
def product(xs, one=1):
result = one
for x in xs:
result = result * x
return result
def exact_div(a, b):
assert a % b == 0, f"{a} is not divisible by {b}"
return a // b
================================================
FILE: metrics/IDM/lib/mlp.py
================================================
import torch as th
from torch import nn
from lib import misc
from lib import torch_util as tu
class MLP(nn.Module):
def __init__(self, insize, nhidlayer, outsize, hidsize, hidactiv, dtype=th.float32):
super().__init__()
self.insize = insize
self.nhidlayer = nhidlayer
self.outsize = outsize
in_sizes = [insize] + [hidsize] * nhidlayer
out_sizes = [hidsize] * nhidlayer + [outsize]
self.layers = nn.ModuleList(
[tu.NormedLinear(insize, outsize, dtype=dtype) for (insize, outsize) in misc.safezip(in_sizes, out_sizes)]
)
self.hidactiv = hidactiv
def forward(self, x):
*hidlayers, finallayer = self.layers
for layer in hidlayers:
x = layer(x)
x = self.hidactiv(x)
x = finallayer(x)
return x
@property
def output_shape(self):
return (self.outsize,)
================================================
FILE: metrics/IDM/lib/normalize_ewma.py
================================================
import numpy as np
import torch
import torch.nn as nn
class NormalizeEwma(nn.Module):
"""Normalize a vector of observations - across the first norm_axes dimensions"""
def __init__(self, input_shape, norm_axes=2, beta=0.99999, per_element_update=False, epsilon=1e-5):
super().__init__()
self.input_shape = input_shape
self.norm_axes = norm_axes
self.epsilon = epsilon
self.beta = beta
self.per_element_update = per_element_update
self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)
self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)
self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False)
def reset_parameters(self):
self.running_mean.zero_()
self.running_mean_sq.zero_()
self.debiasing_term.zero_()
def running_mean_var(self):
debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
return debiased_mean, debiased_var
def forward(self, input_vector):
# Make sure input is float32
input_vector = input_vector.to(torch.float)
if self.training:
# Detach input before adding it to running means to avoid backpropping through it on
# subsequent batches.
detached_input = input_vector.detach()
batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))
batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))
if self.per_element_update:
batch_size = np.prod(detached_input.size()[: self.norm_axes])
weight = self.beta ** batch_size
else:
weight = self.beta
self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))
mean, var = self.running_mean_var()
return (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]
def denormalize(self, input_vector):
"""Transform normalized data back into original distribution"""
mean, var = self.running_mean_var()
return input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
================================================
FILE: metrics/IDM/lib/policy.py
================================================
from copy import deepcopy
from email import policy
from typing import Dict, Optional
import numpy as np
import torch as th
from gym3.types import DictType
from torch import nn
from torch.nn import functional as F
from lib.action_head import make_action_head
from lib.action_mapping import CameraHierarchicalMapping
from lib.impala_cnn import ImpalaCNN
from lib.normalize_ewma import NormalizeEwma
from lib.scaled_mse_head import ScaledMSEHead
from lib.tree_util import tree_map
from lib.util import FanInInitReLULayer, ResidualRecurrentBlocks
from lib.misc import transpose
class ImgPreprocessing(nn.Module):
"""Normalize incoming images.
:param img_statistics: remote path to npz file with a mean and std image. If specified
normalize images using this.
:param scale_img: If true and img_statistics not specified, scale incoming images by 1/255.
"""
def __init__(self, img_statistics: Optional[str] = None, scale_img: bool = True):
super().__init__()
self.img_mean = None
if img_statistics is not None:
img_statistics = dict(**np.load(img_statistics))
self.img_mean = nn.Parameter(th.Tensor(img_statistics["mean"]), requires_grad=False)
self.img_std = nn.Parameter(th.Tensor(img_statistics["std"]), requires_grad=False)
else:
self.ob_scale = 255.0 if scale_img else 1.0
def forward(self, img):
x = img.to(dtype=th.float32)
if self.img_mean is not None:
x = (x - self.img_mean) / self.img_std
else:
x = x / self.ob_scale
return x
class ImgObsProcess(nn.Module):
"""ImpalaCNN followed by a linear layer.
:param cnn_outsize: impala output dimension
:param output_size: output size of the linear layer.
:param dense_init_norm_kwargs: kwargs for linear FanInInitReLULayer
:param init_norm_kwargs: kwargs for 2d and 3d conv FanInInitReLULayer
"""
def __init__(
self,
cnn_outsize: int,
output_size: int,
dense_init_norm_kwargs: Dict = {},
init_norm_kwargs: Dict = {},
**kwargs,
):
super().__init__()
self.cnn = ImpalaCNN(
outsize=cnn_outsize,
init_norm_kwargs=init_norm_kwargs,
dense_init_norm_kwargs=dense_init_norm_kwargs,
**kwargs,
)
self.linear = FanInInitReLULayer(
cnn_outsize,
output_size,
layer_type="linear",
**dense_init_norm_kwargs,
)
def forward(self, img):
return self.linear(self.cnn(img))
class MinecraftPolicy(nn.Module):
"""
:param recurrence_type:
None - No recurrence, adds no extra layers
lstm - (Depreciated). Singular LSTM
multi_layer_lstm - Multi-layer LSTM. Uses n_recurrence_layers to determine number of consecututive LSTMs
Does NOT support ragged batching
multi_masked_lstm - Multi-layer LSTM that supports ragged batching via the first vector. This model is slower
Uses n_recurrence_layers to determine number of consecututive LSTMs
transformer - Dense transformer
:param init_norm_kwargs: kwargs for all FanInInitReLULayers.
"""
def __init__(
self,
recurrence_type="lstm",
impala_width=1,
impala_chans=(16, 32, 32),
obs_processing_width=256,
hidsize=512,
single_output=False, # True if we don't need separate outputs for action/value outputs
img_shape=None,
scale_input_img=True,
only_img_input=False,
init_norm_kwargs={},
impala_kwargs={},
# Unused argument assumed by forc.
input_shape=None, # pylint: disable=unused-argument
active_reward_monitors=None,
img_statistics=None,
first_conv_norm=False,
diff_mlp_embedding=False,
attention_mask_style="clipped_causal",
attention_heads=8,
attention_memory_size=2048,
use_pointwise_layer=True,
pointwise_ratio=4,
pointwise_use_activation=False,
n_recurrence_layers=1,
recurrence_is_residual=True,
timesteps=None,
use_pre_lstm_ln=True, # Not needed for transformer
**unused_kwargs,
):
super().__init__()
assert recurrence_type in [
"multi_layer_lstm",
"multi_layer_bilstm",
"multi_masked_lstm",
"transformer",
"none",
]
active_reward_monitors = active_reward_monitors or {}
self.single_output = single_output
chans = tuple(int(impala_width * c) for c in impala_chans)
self.hidsize = hidsize
# Dense init kwargs replaces batchnorm/groupnorm with layernorm
self.init_norm_kwargs = init_norm_kwargs
self.dense_init_norm_kwargs = deepcopy(init_norm_kwargs)
if self.dense_init_norm_kwargs.get("group_norm_groups", None) is not None:
self.dense_init_norm_kwargs.pop("group_norm_groups", None)
self.dense_init_norm_kwargs["layer_norm"] = True
if self.dense_init_norm_kwargs.get("batch_norm", False):
self.dense_init_norm_kwargs.pop("batch_norm", False)
self.dense_init_norm_kwargs["layer_norm"] = True
# Setup inputs
self.img_preprocess = ImgPreprocessing(img_statistics=img_statistics, scale_img=scale_input_img)
self.img_process = ImgObsProcess(
cnn_outsize=256,
output_size=hidsize,
inshape=img_shape,
chans=chans,
nblock=2,
dense_init_norm_kwargs=self.dense_init_norm_kwargs,
init_norm_kwargs=init_norm_kwargs,
first_conv_norm=first_conv_norm,
**impala_kwargs,
)
self.pre_lstm_ln = nn.LayerNorm(hidsize) if use_pre_lstm_ln else None
self.diff_obs_process = None
self.recurrence_type = recurrence_type
self.recurrent_layer = None
self.recurrent_layer = ResidualRecurrentBlocks(
hidsize=hidsize,
timesteps=timesteps,
recurrence_type=recurrence_type,
is_residual=recurrence_is_residual,
use_pointwise_layer=use_pointwise_layer,
pointwise_ratio=pointwise_ratio,
pointwise_use_activation=pointwise_use_activation,
attention_mask_style=attention_mask_style,
attention_heads=attention_heads,
attention_memory_size=attention_memory_size,
n_block=n_recurrence_layers,
)
self.lastlayer = FanInInitReLULayer(hidsize, hidsize, layer_type="linear", **self.dense_init_norm_kwargs)
self.final_ln = th.nn.LayerNorm(hidsize)
def output_latent_size(self):
return self.hidsize
def forward(self, ob, state_in, context):
first = context["first"]
x = self.img_preprocess(ob["img"])
x = self.img_process(x)
if self.diff_obs_process:
processed_obs = self.diff_obs_process(ob["diff_goal"])
x = processed_obs + x
if self.pre_lstm_ln is not None:
x = self.pre_lstm_ln(x)
if self.recurrent_layer is not None:
x, state_out = self.recurrent_layer(x, first, state_in)
else:
state_out = state_in
x = F.relu(x, inplace=False)
x = self.lastlayer(x)
x = self.final_ln(x)
pi_latent = vf_latent = x
if self.single_output:
return pi_latent, state_out
return (pi_latent, vf_latent), state_out
def initial_state(self, batchsize):
if self.recurrent_layer:
return self.recurrent_layer.initial_state(batchsize)
else:
return None
class MinecraftAgentPolicy(nn.Module):
def __init__(self, action_space, policy_kwargs, pi_head_kwargs):
super().__init__()
self.net = MinecraftPolicy(**policy_kwargs)
self.action_space = action_space
self.value_head = self.make_value_head(self.net.output_latent_size())
self.pi_head = self.make_action_head(self.net.output_latent_size(), **pi_head_kwargs)
def make_value_head(self, v_out_size: int, norm_type: str = "ewma", norm_kwargs: Optional[Dict] = None):
return ScaledMSEHead(v_out_size, 1, norm_type=norm_type, norm_kwargs=norm_kwargs)
def make_action_head(self, pi_out_size: int, **pi_head_opts):
return make_action_head(self.action_space, pi_out_size, **pi_head_opts)
def initial_state(self, batch_size: int):
return self.net.initial_state(batch_size)
def reset_parameters(self):
super().reset_parameters()
self.net.reset_parameters()
self.pi_head.reset_parameters()
self.value_head.reset_parameters()
def forward(self, obs, first: th.Tensor, state_in):
if isinstance(obs, dict):
# We don't want to mutate the obs input.
obs = obs.copy()
# If special "mask" key is in obs,
# It's for masking the logits.
# We take it out (the network doesn't need it)
mask = obs.pop("mask", None)
else:
mask = None
(pi_h, v_h), state_out = self.net(obs, state_in, context={"first": first})
pi_logits = self.pi_head(pi_h, mask=mask)
vpred = self.value_head(v_h)
return (pi_logits, vpred, None), state_out
def get_logprob_of_action(self, pd, action):
"""
Get logprob of taking action `action` given probability distribution
(see `get_gradient_for_action` to get this distribution)
"""
ac = tree_map(lambda x: x.unsqueeze(1), action)
log_prob = self.pi_head.logprob(ac, pd)
assert not th.isnan(log_prob).any()
return log_prob[:, 0]
def get_kl_of_action_dists(self, pd1, pd2):
"""
Get the KL divergence between two action probability distributions
"""
return self.pi_head.kl_divergence(pd1, pd2)
def get_output_for_observation(self, obs, state_in, first):
"""
Return gradient-enabled outputs for given observation.
Use `get_logprob_of_action` to get log probability of action
with the given probability distribution.
Returns:
- probability distribution given observation
- value prediction for given observation
- new state
"""
# We need to add a fictitious time dimension everywhere
obs = tree_map(lambda x: x.unsqueeze(1), obs)
first = first.unsqueeze(1)
(pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)
return pd, self.value_head.denormalize(vpred)[:, 0], state_out
@th.no_grad()
def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False):
# We need to add a fictitious time dimension everywhere
obs = tree_map(lambda x: x.unsqueeze(1), obs)
first = first.unsqueeze(1)
(pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)
if taken_action is None:
ac = self.pi_head.sample(pd, deterministic=not stochastic)
else:
ac = tree_map(lambda x: x.unsqueeze(1), taken_action)
log_prob = self.pi_head.logprob(ac, pd)
assert not th.isnan(log_prob).any()
# After unsqueezing, squeeze back to remove fictitious time dimension
result = {"log_prob": log_prob[:, 0], "vpred": self.value_head.denormalize(vpred)[:, 0]}
if return_pd:
result["pd"] = tree_map(lambda x: x[:, 0], pd)
ac = tree_map(lambda x: x[:, 0], ac)
return ac, state_out, result
@th.no_grad()
def v(self, obs, first, state_in):
"""Predict value for a given mdp observation"""
obs = tree_map(lambda x: x.unsqueeze(1), obs)
first = first.unsqueeze(1)
(pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)
# After unsqueezing, squeeze back
return self.value_head.denormalize(vpred)[:, 0]
class InverseActionNet(MinecraftPolicy):
"""
Args:
conv3d_params: PRE impala 3D CNN params. They are just passed into th.nn.Conv3D.
"""
def __init__(
self,
hidsize=512,
conv3d_params=None,
**MCPoliy_kwargs,
):
super().__init__(
hidsize=hidsize,
# If we're using 3dconv, then we normalize entire impala otherwise don't
# normalize the first impala layer since we normalize the input
first_conv_norm=conv3d_params is not None,
**MCPoliy_kwargs,
)
self.conv3d_layer = None
if conv3d_params is not None:
# 3D conv is the first layer, so don't normalize its input
conv3d_init_params = deepcopy(self.init_norm_kwargs)
conv3d_init_params["group_norm_groups"] = None
conv3d_init_params["batch_norm"] = False
self.conv3d_layer = FanInInitReLULayer(
layer_type="conv3d",
log_scope="3d_conv",
**conv3d_params,
**conv3d_init_params,
)
def forward(self, ob, state_in, context):
first = context["first"]
x = self.img_preprocess(ob["img"])
# Conv3D Prior to Impala
if self.conv3d_layer is not None:
x = self._conv3d_forward(x)
# Impala Stack
x = self.img_process(x)
if self.recurrent_layer is not None:
x, state_out = self.recurrent_layer(x, first, state_in)
x = F.relu(x, inplace=False)
pi_latent = self.lastlayer(x)
pi_latent = self.final_ln(x)
return (pi_latent, None), state_out
def _conv3d_forward(self, x):
# Convert from (B, T, H, W, C) -> (B, H, W, C, T)
x = transpose(x, "bthwc", "bcthw")
new_x = []
for mini_batch in th.split(x, 1):
new_x.append(self.conv3d_layer(mini_batch))
x = th.cat(new_x)
# Convert back
x = transpose(x, "bcthw", "bthwc")
return x
class InverseActionPolicy(nn.Module):
def __init__(
self,
action_space,
pi_head_kwargs=None,
idm_net_kwargs=None,
):
super().__init__()
self.action_space = action_space
self.net = InverseActionNet(**idm_net_kwargs)
pi_out_size = self.net.output_latent_size()
pi_head_kwargs = {} if pi_head_kwargs is None else pi_head_kwargs
self.pi_head = self.make_action_head(pi_out_size=pi_out_size, **pi_head_kwargs)
def make_action_head(self, **kwargs):
return make_action_head(self.action_space, **kwargs)
def reset_parameters(self):
super().reset_parameters()
self.net.reset_parameters()
self.pi_head.reset_parameters()
def forward(self, obs, first: th.Tensor, state_in, **kwargs):
if isinstance(obs, dict):
# We don't want to mutate the obs input.
obs = obs.copy()
# If special "mask" key is in obs,
# It's for masking the logits.
# We take it out (the network doesn't need it)
mask = obs.pop("mask", None)
else:
mask = None
(pi_h, _), state_out = self.net(obs, state_in=state_in, context={"first": first}, **kwargs)
pi_logits = self.pi_head(pi_h, mask=mask)
return (pi_logits, None, None), state_out
@th.no_grad()
def predict(
self,
obs,
deterministic: bool = True,
**kwargs,
):
(pd, _, _), state_out = self(obs=obs, **kwargs)
ac = self.pi_head.sample(pd, deterministic=deterministic)
log_prob = self.pi_head.logprob(ac, pd)
assert not th.isnan(log_prob).any()
result = {"log_prob": log_prob, "pd": pd}
return ac, state_out, result
def initial_state(self, batch_size: int):
return self.net.initial_state(batch_size)
================================================
FILE: metrics/IDM/lib/scaled_mse_head.py
================================================
from typing import Dict, Optional
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from lib.action_head import fan_in_linear
from lib.normalize_ewma import NormalizeEwma
class ScaledMSEHead(nn.Module):
"""
Linear output layer that scales itself so that targets are always normalized to N(0, 1)
"""
def __init__(
self, input_size: int, output_size: int, norm_type: Optional[str] = "ewma", norm_kwargs: Optional[Dict] = None
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.norm_type = norm_type
self.linear = nn.Linear(self.input_size, self.output_size)
norm_kwargs = {} if norm_kwargs is None else norm_kwargs
self.normalizer = NormalizeEwma(output_size, **norm_kwargs)
def reset_parameters(self):
init.orthogonal_(self.linear.weight)
fan_in_linear(self.linear)
self.normalizer.reset_parameters()
def forward(self, input_data):
return self.linear(input_data)
def loss(self, prediction, target):
"""
Calculate the MSE loss between output and a target.
'Prediction' has to be normalized while target is denormalized.
Loss is calculated in a 'normalized' space.
"""
return F.mse_loss(prediction, self.normalizer(target), reduction="mean")
def denormalize(self, input_data):
"""Convert input value from a normalized space into the original one"""
return self.normalizer.denormalize(input_data)
def normalize(self, input_data):
return self.normalizer(input_data)
================================================
FILE: metrics/IDM/lib/torch_util.py
================================================
import functools
import itertools
import math
import os
import pickle
import re
import subprocess
import tempfile
from contextlib import contextmanager
from hashlib import md5, sha1
import numpy as np
import torch as th
import torch.distributed as dist
import torch.distributions as dis
import torch.nn.functional as F
from torch import nn
import lib.tree_util as tree_util
from lib import misc
def contextmanager_to_decorator(cm):
def decorator(fn):
@functools.wraps(fn)
def newfn(*args, **kwargs):
with cm():
return fn(*args, **kwargs)
return newfn
return decorator
def have_cuda():
return th.has_cuda
def default_device_type():
return "cuda" if have_cuda() else "cpu"
no_grad = contextmanager_to_decorator(th.no_grad)
DEFAULT_DEVICE = th.device(type=default_device_type())
def set_default_torch_device(device):
global DEFAULT_DEVICE
DEFAULT_DEVICE = th.device(device)
def dev():
return DEFAULT_DEVICE
def zeros(*args, **kwargs):
return th.zeros(*args, **kwargs, device=dev())
def ones(*args, **kwargs):
return th.ones(*args, **kwargs, device=dev())
def arange(*args, **kwargs):
return th.arange(*args, **kwargs, device=dev())
def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):
"""
nn.Linear but with normalized fan-in init
"""
dtype = parse_dtype(dtype)
if dtype == th.float32:
out = nn.Linear(*args, **kwargs)
elif dtype == th.float16:
out = LinearF16(*args, **kwargs)
else:
raise ValueError(dtype)
out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)
if kwargs.get("bias", True):
out.bias.data *= 0
return out
class LinearF16(nn.Linear):
def forward(self, x):
return F.linear(x, self.weight.half(), self.bias.half() if self.bias is not None else None)
class LayerNormF16(nn.LayerNorm):
def forward(self, x):
return F.layer_norm(x, self.normalized_shape, self.weight.half(), self.bias.half(), self.eps)
def LayerNorm(*args, dtype=th.float32, **kwargs):
dtype = parse_dtype(dtype)
if dtype == th.float32:
out = nn.LayerNorm(*args, **kwargs)
elif dtype == th.float16:
out = LayerNormF16(*args, **kwargs)
else:
raise ValueError(dtype)
out.weight.no_scale = True
return out
def flatten_image(x):
"""
Flattens last three dims
"""
*batch_shape, h, w, c = x.shape
return x.reshape((*batch_shape, h * w * c))
def sequential(layers, x, *args, diag_name=None, use_checkpoint=False):
for (i, layer) in enumerate(layers):
x = layer(x, *args)
return x
@no_grad
def load_average_with_metadata(paths, overrides):
n_models = len(paths)
model, metadata = load_with_metadata(paths[0], overrides=overrides)
for p in model.parameters():
p.mul_(1 / n_models)
for p in paths[1:]:
new_model, _ = load_with_metadata(p, overrides=overrides)
for (n1, p1), (n2, p2) in misc.safezip(model.named_parameters(), new_model.named_parameters()):
assert n1 == n2, f"names {n1} and {n2} don't match"
p1.add_(p2.mul_(1 / n_models))
return model, metadata
def save_kwargs(fn):
"""
This decorator passes through the user-provided kwargs and adds one more, called
save_kwargs, mapping to {"create_fn" : name_of_decorated_fn, "kwargs" : other_kwargs}
You put on this decorator on a function that creates a pytorch module. This will
save the kwargs and the function that was used to create the module.
This lets us restore the model state later.
"""
@functools.wraps(fn)
def wrapper(**kwargs):
if "save_kwargs" in kwargs:
return fn(**kwargs)
else:
sk = {**kwargs, "create_fn": f"{fn.__module__}:{fn.__name__}"}
return fn(save_kwargs=sk, **kwargs)
return wrapper
def parse_dtype(x):
if isinstance(x, th.dtype):
return x
elif isinstance(x, str):
if x == "float32" or x == "float":
return th.float32
elif x == "float64" or x == "double":
return th.float64
elif x == "float16" or x == "half":
return th.float16
elif x == "uint8":
return th.uint8
elif x == "int8":
return th.int8
elif x == "int16" or x == "short":
return th.int16
elif x == "int32" or x == "int":
return th.int32
elif x == "int64" or x == "long":
return th.int64
elif x == "bool":
return th.bool
else:
raise ValueError(f"cannot parse {x} as a dtype")
else:
raise TypeError(f"cannot parse {type(x)} as dtype")
def index(x, i):
"""
Batched, broadcasting index of x along dimension i.ndim.
For example, if x has shape (1, 2, 3, 4, 5) and i has shape (1, 1, 3)
then the result has shape (1, 2, 3, 5) and each value in i must be between 0 and 3.
"""
assert x.ndim >= i.ndim + 1
gather_dim = i.ndim
while i.ndim < x.ndim:
i = i.unsqueeze(-1)
expand_shape = list(x.shape)
expand_shape[gather_dim] = 1
i = i.expand(*expand_shape)
xi = th.gather(x, gather_dim, i)
assert xi.shape[gather_dim] == 1
return xi.squeeze(gather_dim)
================================================
FILE: metrics/IDM/lib/tree_util.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copied this from jax, made it self-contained
# Currently just used for improved_checkpoint
import collections
import functools
import itertools as it
from collections.abc import Collection
from typing import Dict, List, Optional
def unzip2(xys):
xs = []
ys = []
for x, y in xys:
xs.append(x)
ys.append(y)
return tuple(xs), tuple(ys)
def partial(fun, *args, **kwargs):
wrapped = functools.partial(fun, *args, **kwargs)
functools.update_wrapper(wrapped, fun)
wrapped._bound_args = args # pylint: disable=protected-access
return wrapped
def safe_zip(*args: Collection) -> List[tuple]:
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, "length mismatch: {}".format(list(map(len, args)))
return list(zip(*args))
def safe_map(f, *args):
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, "length mismatch: {}".format(list(map(len, args)))
return list(map(f, *args))
def tree_map(f, tree, treat_as_leaves: Optional[List] = None):
"""Map a function over a pytree to produce a new pytree.
Args:
f: function to be applied at each leaf.
tree: a pytree to be mapped over.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
`tree`.
"""
if treat_as_leaves is None:
treat_as_leaves = []
node_type = node_types.get(type(tree))
if node_type and type(tree) not in treat_as_leaves:
children, node_spec = node_type.to_iterable(tree)
new_children = [tree_map(f, child, treat_as_leaves) for child in children]
return node_type.from_iterable(node_spec, new_children)
else:
return f(tree)
def tree_multimap(f, tree, *rest, treat_as_leaves: Optional[List] = None):
"""Map a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes `1 + len(rest)` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to `f`.
*rest: a tuple of pytrees, each with the same structure as `tree`.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding leaves in `rest`.
"""
if treat_as_leaves is None:
treat_as_leaves = []
node_type = node_types.get(type(tree))
if node_type and type(tree) not in treat_as_leaves:
children, node_spec = node_type.to_iterable(tree)
all_children = [children]
for other_tree in rest:
other_children, other_node_data = node_type.to_iterable(other_tree)
if other_node_data != node_spec:
raise TypeError("Mismatch: {} != {}".format(other_node_data, node_spec))
all_children.append(other_children)
new_children = [tree_multimap(f, *xs, treat_as_leaves=treat_as_leaves) for xs in zip(*all_children)]
return node_type.from_iterable(node_spec, new_children)
else:
return f(tree, *rest)
def prefix_multimap(f, treedef, tree, *rest):
"""Like tree_multimap but only maps down through a tree prefix."""
if isinstance(treedef, PyLeaf):
return f(tree, *rest)
else:
node_type = node_types.get(type(tree))
if node_type != treedef.node_type:
raise TypeError("Mismatch: {} != {}".format(treedef.node_type, node_type))
children, node_data = node_type.to_iterable(tree)
if node_data != treedef.node_data:
raise TypeError("Mismatch: {} != {}".format(treedef.node_data, node_data))
all_children = [children]
for other_tree in rest:
other_children, other_node_data = node_type.to_iterable(other_tree)
if other_node_data != node_data:
raise TypeError("Mismatch: {} != {}".format(other_node_data, node_data))
all_children.append(other_children)
all_children = zip(*all_children)
new_children = [prefix_multimap(f, td, *xs) for td, xs in zip(treedef.children, all_children)]
return node_type.from_iterable(node_data, new_children)
def walk_pytree(f_node, f_leaf, tree, treat_as_leaves: Optional[List] = None):
node_type = node_types.get(type(tree))
if treat_as_leaves is None:
treat_as_leaves = []
if node_type and type(tree) not in treat_as_leaves:
children, node_spec = node_type.to_iterable(tree)
proc_children, child_specs = unzip2([walk_pytree(f_node, f_leaf, child, treat_as_leaves) for child in children])
tree_def = PyTreeDef(node_type, node_spec, child_specs)
return f_node(proc_children), tree_def
else:
return f_leaf(tree), PyLeaf()
def build_tree(treedef, xs):
if isinstance(treedef, PyLeaf):
return xs
else:
# We use 'iter' for clearer error messages
children = safe_map(build_tree, iter(treedef.children), iter(xs))
return treedef.node_type.from_iterable(treedef.node_data, children)
def _tree_unflatten(xs, treedef):
if isinstance(treedef, PyLeaf):
return next(xs)
else:
children = safe_map(partial(_tree_unflatten, xs), treedef.children)
return treedef.node_type.from_iterable(treedef.node_data, children)
def _num_leaves(treedef):
return 1 if isinstance(treedef, PyLeaf) else sum(safe_map(_num_leaves, treedef.children))
def _nested_treedef(inner, outer):
# just used in tree_transpose error checking
if isinstance(outer, PyLeaf):
return inner
else:
children = safe_map(partial(_nested_treedef, inner), outer.children)
return PyTreeDef(outer.node_type, outer.node_data, tuple(children))
class PyTreeDef(object):
def __init__(self, node_type, node_data, children):
self.node_type = node_type
self.node_data = node_data
self.children = children
def __repr__(self):
if self.node_data is None:
data_repr = ""
else:
data_repr = "[{}]".format(self.node_data)
return "PyTree({}{}, [{}])".format(self.node_type.name, data_repr, ",".join(safe_map(repr, self.children)))
def __hash__(self):
return hash((self.node_type, self.node_data, tuple(self.children)))
def __eq__(self, other):
if isinstance(other, PyLeaf):
return False
else:
return self.node_type == other.node_type and self.node_data == other.node_data and self.children == other.children
def __ne__(self, other):
return not self == other
class PyLeaf(object):
def __repr__(self):
return "*"
def __eq__(self, other):
return isinstance(other, PyLeaf)
class NodeType(object):
def __init__(self, name, to_iterable, from_iterable):
self.name = name
self.to_iterable = to_iterable
self.from_iterable = from_iterable
node_types: Dict[type, NodeType] = {}
def register_pytree_node(py_type, to_iterable, from_iterable):
assert py_type not in node_types
node_types[py_type] = NodeType(str(py_type), to_iterable, from_iterable)
def tuple_to_iterable(xs):
return xs, None
def tuple_from_iterable(_keys, xs):
return tuple(xs)
def list_to_iterable(xs):
return tuple(xs), None
def list_from_iterable(_keys, xs):
return list(xs)
def dict_to_iterable(xs):
keys = tuple(sorted(xs.keys()))
return tuple(map(xs.get, keys)), keys
def dict_from_iterable(keys, xs):
return dict(safe_zip(keys, xs))
def ordered_dict_from_iterable(keys, xs):
return collections.OrderedDict(safe_zip(keys, xs))
def default_dict_to_iterable(xs):
return (tuple(xs.values()), (xs.default_factory, tuple(xs.keys())))
def default_dict_from_iterable(keys, xs):
return collections.defaultdict(keys[0], safe_zip(keys[1], xs))
def none_to_iterable(_xs):
return (), None
def none_from_iterable(_keys, _xs):
return None
register_pytree_node(tuple, tuple_to_iterable, tuple_from_iterable)
register_pytree_node(list, list_to_iterable, list_from_iterable)
register_pytree_node(dict, dict_to_iterable, dict_from_iterable)
register_pytree_node(collections.OrderedDict, dict_to_iterable, ordered_dict_from_iterable)
register_pytree_node(collections.defaultdict, default_dict_to_iterable, default_dict_from_iterable)
register_pytree_node(type(None), none_to_iterable, none_from_iterable)
================================================
FILE: metrics/IDM/lib/util.py
================================================
from typing import Dict, Optional
import torch as th
from torch import nn
from torch.nn import functional as F
import lib.torch_util as tu
from lib.masked_attention import MaskedAttention
from lib.minecraft_util import store_args
from lib.tree_util import tree_map
def get_module_log_keys_recursive(m: nn.Module):
"""Recursively get all keys that a module and its children want to log."""
keys = []
if hasattr(m, "get_log_keys"):
keys += m.get_log_keys()
for c in m.children():
keys += get_module_log_keys_recursive(c)
return keys
class FanInInitReLULayer(nn.Module):
"""Implements a slightly modified init that correctly produces std 1 outputs given ReLU activation
:param inchan: number of input channels
:param outchan: number of output channels
:param layer_args: positional layer args
:param layer_type: options are "linear" (dense layer), "conv" (2D Convolution), "conv3d" (3D convolution)
:param init_scale: multiplier on initial weights
:param batch_norm: use batch norm after the layer (for 2D data)
:param group_norm_groups: if not None, use group norm with this many groups after the layer. Group norm 1
would be equivalent of layernorm for 2D data.
:param layer_norm: use layernorm after the layer (for 1D data)
:param layer_kwargs: keyword arguments for the layer
"""
@store_args
def __init__(
self,
inchan: int,
outchan: int,
*layer_args,
layer_type: str = "conv",
init_scale: int = 1,
batch_norm: bool = False,
batch_norm_kwargs: Dict = {},
group_norm_groups: Optional[int] = None,
layer_norm: bool = False,
use_activation=True,
log_scope: Optional[str] = None,
**layer_kwargs,
):
super().__init__()
# Normalization
self.norm = None
if batch_norm:
self.norm = nn.BatchNorm2d(inchan, **batch_norm_kwargs)
elif group_norm_groups is not None:
self.norm = nn.GroupNorm(group_norm_groups, inchan)
elif layer_norm:
self.norm = nn.LayerNorm(inchan)
layer = dict(conv=nn.Conv2d, conv3d=nn.Conv3d, linear=nn.Linear)[layer_type]
self.layer = layer(inchan, outchan, bias=self.norm is None, *layer_args, **layer_kwargs)
# Init Weights (Fan-In)
self.layer.weight.data *= init_scale / self.layer.weight.norm(
dim=tuple(range(1, self.layer.weight.data.ndim)), p=2, keepdim=True
)
# Init Bias
if self.layer.bias is not None:
self.layer.bias.data *= 0
def forward(self, x):
"""Norm after the activation. Experimented with this for both IAM and BC and it was slightly better."""
if self.norm is not None:
x = self.norm(x)
x = self.layer(x)
if self.use_activation:
x = F.relu(x, inplace=True)
return x
def get_log_keys(self):
return [
f"activation_mean/{self.log_scope}",
f"activation_std/{self.log_scope}",
]
class ResidualRecurrentBlocks(nn.Module):
@store_args
def __init__(
self,
n_block=2,
recurrence_type="multi_layer_lstm",
is_residual=True,
**block_kwargs,
):
super().__init__()
init_scale = n_block ** -0.5 if is_residual else 1
self.blocks = nn.ModuleList(
[
ResidualRecurrentBlock(
**block_kwargs,
recurrence_type=recurrence_type,
is_residual=is_residual,
init_scale=init_scale,
block_number=i,
)
for i in range(n_block)
]
)
def forward(self, x, first, state):
state_out = []
assert len(state) == len(
self.blocks
), f"Length of state {len(state)} did not match length of blocks {len(self.blocks)}"
for block, _s_in in zip(self.blocks, state):
x, _s_o = block(x, first, _s_in)
state_out.append(_s_o)
return x, state_out
def initial_state(self, batchsize):
if "lstm" in self.recurrence_type:
return [None for b in self.blocks]
else:
return [b.r.initial_state(batchsize) for b in self.blocks]
class ResidualRecurrentBlock(nn.Module):
@store_args
def __init__(
self,
hidsize,
timesteps,
init_scale=1,
recurrence_type="multi_layer_lstm",
is_residual=True,
use_pointwise_layer=True,
pointwise_ratio=4,
pointwise_use_activation=False,
attention_heads=8,
attention_memory_size=2048,
attention_mask_style="clipped_causal",
log_scope="resblock",
block_number=0,
):
super().__init__()
self.log_scope = f"{log_scope}{block_number}"
s = init_scale
if use_pointwise_layer:
if is_residual:
s *= 2 ** -0.5 # second residual
self.mlp0 = FanInInitReLULayer(
hidsize,
hidsize * pointwise_ratio,
init_scale=1,
layer_type="linear",
layer_norm=True,
log_scope=self.log_scope + "/ptwise_mlp0",
)
self.mlp1 = FanInInitReLULayer(
hidsize * pointwise_ratio,
hidsize,
init_scale=s,
layer_type="linear",
use_activation=pointwise_use_activation,
log_scope=self.log_scope + "/ptwise_mlp1",
)
self.pre_r_ln = nn.LayerNorm(hidsize)
if recurrence_type in ["multi_layer_lstm", "multi_layer_bilstm"]:
self.r = nn.LSTM(hidsize, hidsize, batch_first=True)
nn.init.normal_(self.r.weight_hh_l0, std=s * (self.r.weight_hh_l0.shape[0] ** -0.5))
nn.init.normal_(self.r.weight_ih_l0, std=s * (self.r.weight_ih_l0.shape[0] ** -0.5))
self.r.bias_hh_l0.data *= 0
self.r.bias_ih_l0.data *= 0
elif recurrence_type == "transformer":
self.r = MaskedAttention(
input_size=hidsize,
timesteps=timesteps,
memory_size=attention_memory_size,
heads=attention_heads,
init_scale=s,
norm="none",
log_scope=log_scope + "/sa",
use_muP_factor=True,
mask=attention_mask_style,
)
def forward(self, x, first, state):
residual = x
x = self.pre_r_ln(x)
x, state_out = recurrent_forward(
self.r,
x,
first,
state,
reverse_lstm=self.recurrence_type == "multi_layer_bilstm" and (self.block_number + 1) % 2 == 0,
)
if self.is_residual and "lstm" in self.recurrence_type: # Transformer already residual.
x = x + residual
if self.use_pointwise_layer:
# Residual MLP
residual = x
x = self.mlp1(self.mlp0(x))
if self.is_residual:
x = x + residual
return x, state_out
def recurrent_forward(module, x, first, state, reverse_lstm=False):
if isinstance(module, nn.LSTM):
if state is not None:
# In case recurrent models do not accept a "first" argument we zero out the hidden state here
mask = 1 - first[:, 0, None, None].to(th.float)
state = tree_map(lambda _s: _s * mask, state)
state = tree_map(lambda _s: _s.transpose(0, 1), state) # NL, B, H
if reverse_lstm:
x = th.flip(x, [1])
x, state_out = module(x, state)
if reverse_lstm:
x = th.flip(x, [1])
state_out = tree_map(lambda _s: _s.transpose(0, 1), state_out) # B, NL, H
return x, state_out
else:
return module(x, first, state)
def _banded_repeat(x, t):
"""
Repeats x with a shift.
For example (ignoring the batch dimension):
_banded_repeat([A B C D E], 4)
=
[D E 0 0 0]
[C D E 0 0]
[B C D E 0]
[A B C D E]
"""
b, T = x.shape
x = th.cat([x, x.new_zeros(b, t - 1)], dim=1)
result = x.unfold(1, T, 1).flip(1)
return result
def bandify(b_nd, t, T):
"""
b_nd -> D_ntT, where
"n" indexes over basis functions
"d" indexes over time differences
"t" indexes over output time
"T" indexes over input time
only t >= T is nonzero
B_ntT[n, t, T] = b_nd[n, t - T]
"""
nbasis, bandsize = b_nd.shape
b_nd = b_nd[:, th.arange(bandsize - 1, -1, -1)]
if bandsize >= T:
b_nT = b_nd[:, -T:]
else:
b_nT = th.cat([b_nd.new_zeros(nbasis, T - bandsize), b_nd], dim=1)
D_tnT = _banded_repeat(b_nT, t)
return D_tnT
def get_norm(name, d, dtype=th.float32):
if name == "none":
return lambda x: x
elif name == "layer":
return tu.LayerNorm(d, dtype=dtype)
else:
raise NotImplementedError(name)
================================================
FILE: metrics/IDM/lib/xf.py
================================================
"""
Implementation of transformer and reshaping-based sparse transformer
"""
import functools
import math
import torch as th
from torch import nn
from torch.nn import functional as F
from lib import misc, mlp
from lib import torch_util as tu
from lib import util
SENTINEL = 0.1337
def attention(
Q_bte,
K_bTe,
V_bTe,
dtype,
mask=True,
extra_btT=None,
maxlen=None,
check_sentinel=False,
use_muP_factor=False,
):
"""
performs softmax(Q*K)*V operation
t : output (write) time axis, possibly size=1 for just the last timestep
T : input (read) time axis
t < T is OK
'check_sentinel' is used when you want to make it impossible to attend to certain keys.
All keys where every value is equal to the constant SENTINEL will be ignored.
Currently this is only used by StridedAttn.
"""
assert Q_bte.dtype == K_bTe.dtype == dtype, f"{Q_bte.dtype}, {K_bTe.dtype}, {dtype} must all match"
e = Q_bte.shape[2]
if check_sentinel:
invalid = (K_bTe == SENTINEL).int().sum(dim=-1) == e
invalid = misc.reshape(invalid, "b, T", "b, 1, T")
if isinstance(mask, th.Tensor):
bias = (~mask).float() * -1e9
elif mask:
bias = get_attn_bias_cached(Q_bte.shape[1], K_bTe.shape[1], maxlen=maxlen, device=Q_bte.device, dtype=th.float32)
else:
bias = Q_bte.new_zeros((), dtype=th.float32)
if extra_btT is not None:
bias = bias + extra_btT
# Equivalent to bias + (1 / math.sqrt(e)) * th.einsum("bte,bpe->btp", Q_bte, K_bte)
# but faster:
logit_btT = th.baddbmm(
bias,
Q_bte.float(),
K_bTe.float().transpose(-1, -2),
alpha=(1 / e) if use_muP_factor else (1 / math.sqrt(e)),
)
if check_sentinel:
logit_btT = logit_btT - 1e9 * invalid.float()
W_btT = th.softmax(logit_btT, dim=2).to(dtype)
if callable(V_bTe):
# This is used by the sharded video model to defer waiting on
# the broadcast of the values until they're needed
V_bTe = V_bTe()
# th.einsum only lets you use lowercase letters, so 'p' for 'past'
# means 'T'
A_bte = th.einsum("btp,bpe->bte", W_btT, V_bTe)
return A_bte
class Attn:
"""
Defines an attention mechanism
All the mechanisms here can be defined by two operations:
1. preprocessing Q,K,V,R[=relative attention query]
to move axes from embedding dimension to
batch dimension, and possibly doing shifts.
2. postprocessing the final result to move axes back to embedding
axis.
"""
def __init__(self, mask, maxlen):
self.mask = mask
self.maxlen = maxlen
def preproc_qkv(self, Q_bte, K_bte, V_bte):
raise NotImplementedError
def preproc_r(self, R_btn):
raise NotImplementedError
def split_heads(x_bte, h):
b, t, e = x_bte.shape
assert e % h == 0, "Embsize must be divisible by number of heads"
q = e // h
x_bthq = x_bte.reshape((b, t, h, q))
x_bhtq = misc.transpose(x_bthq, "bthq", "bhtq")
x_Btq = x_bhtq.reshape((b * h, t, q))
return x_Btq
class All2All(Attn):
def __init__(self, nhead, maxlen, mask=True, head_dim=None):
super().__init__(mask=mask, maxlen=maxlen)
assert (nhead is None) != (head_dim is None), "exactly one of nhead and head_dim must be specified"
self.h = nhead
self.head_dim = head_dim
def preproc_qkv(self, *xs):
q = xs[0].shape[-1]
for x in xs:
assert x.shape[-1] == q, "embedding dimensions do not match"
h = self.h or misc.exact_div(q, self.head_dim)
postproc = functools.partial(self.postproc_a, h=h)
return (postproc, *tuple(split_heads(x, h) for x in xs))
def preproc_r(self, R_btn):
_, ret = self.preproc_qkv(R_btn)
return ret
def postproc_a(self, A_Btq, h):
B, t, q = A_Btq.shape
b = B // h
A_bhtq = A_Btq.reshape((b, h, t, q))
A_bthq = misc.transpose(A_bhtq, "bhtq", "bthq")
A_bte = A_bthq.reshape((b, t, h * q))
return A_bte
def _required_padding(dim, target_div):
if dim % target_div == 0:
return 0
else:
return target_div - dim % target_div
class StridedAttn(Attn):
def __init__(self, nhead, stride, maxlen, mask=True):
super().__init__(mask=mask, maxlen=maxlen)
self.h = nhead
self.stride = stride
def _preproc(self, x, name, Q_t=None, Q_pad=None):
x, undo = misc.reshape_undo(x, "b, t*stride, e", "b, 1, t, stride*e", stride=self.stride)
if name == "Q":
Q_pad = _required_padding(x.shape[2], self.maxlen)
original_t = x.shape[2]
x = F.pad(x, (0, 0, 0, Q_pad), value=SENTINEL)
undo = misc.compose_undo(undo, lambda x: x[:, :, :original_t])
if name == "Q":
Q_t = x.shape[2]
assert Q_t % self.maxlen == 0, f"{Q_t} % {self.maxlen} != 0"
else:
required_len = Q_t + self.maxlen
if x.shape[2] < required_len:
x = F.pad(x, (0, 0, required_len - x.shape[2], 0), value=SENTINEL)
assert x.shape[2] >= required_len
back = x[:, :, -Q_t - self.maxlen : -self.maxlen]
front = x[:, :, -Q_t:]
x = th.cat([back, front], dim=1)
_, _, t, _ = x.shape
assert t == Q_t, f"{t} != {Q_t}"
x, undo = misc.reshape_undo(
x,
"b, pad_shift, t*maxlen, stride*h*q",
"b, pad_shift, t, maxlen, stride, h, q",
maxlen=self.maxlen,
h=self.h,
stride=self.stride,
undo=undo,
)
x, undo = misc.transpose_undo(x, "bptmshq", "bthspmq", undo=undo)
x, undo = misc.reshape_undo(
x,
"b, t, h, stride, pad_shift, maxlen, q",
"b*t*h*stride, pad_shift*maxlen, q",
undo=undo,
)
if name == "Q":
return x, undo, Q_t, Q_pad
else:
return x
def preproc_qkv(self, Q_bte, K_bte, V_bte):
pad = _required_padding(Q_bte.shape[1], self.stride)
if pad:
Q_bte = F.pad(Q_bte, (0, 0, 0, pad), value=SENTINEL)
K_bte = F.pad(K_bte, (0, 0, 0, pad), value=SENTINEL) if K_bte is not None else None
V_bte = F.pad(V_bte, (0, 0, 0, pad), value=SENTINEL) if V_bte is not None else None
undo = lambda x, pad=pad: x[:, :-pad]
else:
undo = None
if K_bte is not None:
pad = _required_padding(K_bte.shape[1], self.stride)
if pad:
K_bte = F.pad(K_bte, (0, 0, pad, 0), value=SENTINEL)
V_bte = F.pad(V_bte, (0, 0, pad, 0), value=SENTINEL)
assert Q_bte.shape[1] % self.stride == 0
assert K_bte is None or K_bte.shape[1] % self.stride == 0
assert V_bte is None or V_bte.shape[1] % self.stride == 0
Q, postproc, Q_t, Q_pad = self._preproc(Q_bte, "Q")
postproc = misc.compose_undo(undo, postproc)
return (
postproc,
Q,
self._preproc(K_bte, "K", Q_t=Q_t, Q_pad=Q_pad) if K_bte is not None else None,
self._preproc(V_bte, "V", Q_t=Q_t, Q_pad=Q_pad) if V_bte is not None else None,
)
def preproc_r(self, R_bte):
_, R, _, _ = self.preproc_qkv(R_bte, None, None)
return R
Q_SCALE = 0.1
K_SCALE = 0.2
V_SCALE = 1.0
PROJ_SCALE = 1.0
MLP0_SCALE = 1.0
MLP1_SCALE = 1.0
R_SCALE = 0.1
B_SCALE = 0.2
class AttentionLayerBase(nn.Module):
def __init__(
self,
*,
attn,
scale,
x_size,
c_size,
qk_size,
v_size,
dtype,
relattn=False,
seqlens=None,
separate=False,
):
super().__init__()
dtype = tu.parse_dtype(dtype)
self.attn = attn
self.x_size = x_size
self.c_size = c_size
s = math.sqrt(scale)
separgs = dict(seqlens=seqlens, separate=separate)
self.q_layer = MultiscaleLinear(x_size, qk_size, name="q", scale=Q_SCALE, dtype=dtype, **separgs)
self.k_layer = MultiscaleLinear(c_size, qk_size, name="k", scale=K_SCALE, bias=False, dtype=dtype, **separgs)
self.v_layer = MultiscaleLinear(c_size, v_size, name="v", scale=V_SCALE * s, bias=False, dtype=dtype, **separgs)
self.proj_layer = MultiscaleLinear(v_size, x_size, name="proj", scale=PROJ_SCALE * s, dtype=dtype, **separgs)
self.relattn = relattn
maxlen = attn.maxlen
assert maxlen > 0 or not attn.mask
if self.relattn:
nbasis = 10
self.r_layer = tu.NormedLinear(x_size, nbasis * attn.h, scale=R_SCALE, dtype=dtype)
self.b_nd = nn.Parameter(th.randn(nbasis, maxlen) * B_SCALE)
self.maxlen = maxlen
self.dtype = dtype
def relattn_logits(self, X_bte, T):
R_btn = self.r_layer(X_bte).float()
R_btn = self.attn.preproc_r(R_btn)
t = R_btn.shape[1]
D_ntT = util.bandify(self.b_nd, t, T)
extra_btT = th.einsum("btn,ntp->btp", R_btn, D_ntT)
return extra_btT
def quick_gelu(x):
return x * th.sigmoid(1.702 * x)
def act(actname, x):
if actname == "relu":
return F.relu(x)
elif actname == "gelu":
return quick_gelu(x)
elif actname == "none":
return x
else:
raise NotImplementedError(actname)
class SelfAttentionLayer(AttentionLayerBase):
"""
Residual attention layer that takes a single tensor x and has it attend to itself
Has the form
output = x + f(x)
"""
def __init__(
self,
x_size,
attn,
scale,
dtype="float32",
norm="layer",
cache_keep_len=None,
relattn=False,
log_scope="sa",
use_muP_factor=False,
**kwargs,
):
super().__init__(
x_size=x_size,
c_size=x_size,
qk_size=x_size,
v_size=x_size,
attn=attn,
scale=scale,
relattn=relattn,
dtype=dtype,
**kwargs,
)
self.ln_x = util.get_norm(norm, x_size, dtype=dtype)
if cache_keep_len is None:
if hasattr(attn, "cache_keep_len"):
cache_keep_len = attn.cache_keep_len
else:
if isinstance(attn, StridedAttn):
stride = attn.stride
else:
stride = 1
cache_keep_len = stride * attn.maxlen
self.cache_keep_len = cache_keep_len
self.log_scope = log_scope
self.use_muP_factor = use_muP_factor
def residual(self, X_bte, state):
X_bte = self.ln_x(X_bte)
Q_bte = self.q_layer(X_bte)
K_bte = self.k_layer(X_bte)
V_bte = self.v_layer(X_bte)
if state:
state, K_bte, V_bte = self.update_state(state, K_bte, V_bte)
postproc_closure, Q_bte, K_bte, V_bte = self.attn.preproc_qkv(Q_bte, K_bte, V_bte)
extra_btT = self.relattn_logits(X_bte, K_bte.shape[1]) if self.relattn else None
A_bte = attention(
Q_bte,
K_bte,
V_bte,
mask=self.attn.mask,
extra_btT=extra_btT,
maxlen=self.maxlen,
dtype=self.dtype,
check_sentinel=isinstance(self.attn, StridedAttn),
use_muP_factor=self.use_muP_factor,
)
A_bte = postproc_closure(A_bte)
Aproj_bte = self.proj_layer(A_bte)
return Aproj_bte, state
def forward(self, X_bte, state):
R_bte, state = self.residual(X_bte, state)
return X_bte + R_bte, state
def stateless_forward(self, X_bte):
out_bte, _state = self.forward(X_bte, None)
return out_bte
def update_state(self, state, K_bte, V_bte):
def append(prev, new):
"""
Given `prev` keys from cache, and `new` keys,
returns (cache, full), where
- cache goes into the output state, length chosen so that on the
next timestep, there are enough cached timesteps to get the full
context of lenth self.maxlen.
- full is used for the current forward pass, with length chosen so
that the first timestep new[:, 0] gets to see a context of
self.maxlen.
"""
tprev = prev.shape[1]
startfull = max(tprev - self.cache_keep_len, 0)
full = th.cat([prev[:, startfull:], new], dim=1)
outstate = full[:, max(full.shape[1] - (self.cache_keep_len), 0) :]
# To see that the preceding slicing is correct, consider the case
# that maxlen==1. Then `full` only consists of `new`, and
# `outstate` is empty
return outstate, full
instate_K, instate_V = state
outstate_K, K_bte = append(instate_K, K_bte)
outstate_V, V_bte = append(instate_V, V_bte)
assert outstate_K.shape[-2] <= self.cache_keep_len
return (outstate_K, outstate_V), K_bte, V_bte
def initial_state(self, batchsize, initial_T=0):
return (
tu.zeros((batchsize, initial_T, self.x_size), dtype=self.dtype),
tu.zeros((batchsize, initial_T, self.x_size), dtype=self.dtype),
)
def empty_state(self):
return None
class PointwiseLayer(nn.Module):
"""
Residual MLP applied at each timestep
"""
def __init__(self, x_size, scale, dtype, norm, actname="relu", mlp_ratio=2):
super().__init__()
s = math.sqrt(scale)
self.ln = util.get_norm(norm, x_size, dtype=dtype)
self.mlp = mlp.MLP(
insize=x_size,
nhidlayer=1,
outsize=x_size,
hidsize=int(x_size * mlp_ratio),
hidactiv=functools.partial(act, actname),
dtype=dtype,
)
self.mlp.layers[0].weight.data *= MLP0_SCALE * s
self.mlp.layers[1].weight.data *= MLP1_SCALE * s
def residual(self, x):
x = self.ln(x)
x = self.mlp(x)
return x
def forward(self, x):
return x + self.residual(x)
def _is_separate(sep, name):
if isinstance(sep, bool):
return sep
assert isinstance(sep, set)
if name in sep:
sep.remove(name)
return True
else:
return False
def make_maybe_multiscale(make_fn, *args, seqlens, separate, name, **kwargs):
"""
This function either creates one instance of a module or creates
a separate instance of the module for each resolution of the image,
determined by the `separate` parameter. We create separate modules
if `separate` is True or if `separate` is a set containing `name`.
"""
if _is_separate(separate, name):
modules = [make_fn(*args, **kwargs) for _ in seqlens]
return SplitCallJoin(modules, seqlens)
else:
return make_fn(*args, **kwargs)
class SplitCallJoin(nn.Module):
def __init__(self, mods, seqlens):
super().__init__()
self.mods = nn.ModuleList(mods)
self.seqlens = seqlens
def forward(self, x):
tl = sum(self.seqlens)
x, undo = misc.reshape_undo(x, "..., z*tl, e", "..., z, tl, e", tl=tl)
x = list(th.split(x, self.seqlens, dim=-2))
new_x = []
for x, mod in misc.safezip(x, self.mods):
x, this_undo = misc.reshape_undo(x, "..., z, l, e", "..., z*l, e")
x = mod(x)
x = this_undo(x)
new_x.append(x)
x = th.cat(new_x, dim=-2)
x = undo(x)
return x
MultiscaleLinear = functools.partial(make_maybe_multiscale, tu.NormedLinear)
MultiscalePointwise = functools.partial(make_maybe_multiscale, PointwiseLayer)
================================================
FILE: metrics/common_metrics.py
================================================
import sys
import os
sys.path.append(os.getcwd())
from common_metrics_on_video_quality.calculate_fvd import calculate_fvd
from common_metrics_on_video_quality.calculate_lpips import calculate_lpips
from common_metrics_on_video_quality.calculate_ssim import calculate_ssim
from common_metrics_on_video_quality.calculate_psnr import calculate_psnr
import os
import cv2
import torch
import numpy as np
import argparse
import json
device = torch.device("cuda")
def load_videos_to_tensor(video_dir, number_of_videos, video_length, channel, size,video_files=None):
videos_tensor = torch.zeros(number_of_videos, video_length, channel, size[0], size[1], requires_grad=False)
if video_files is None:
video_files = [f for f in os.listdir(video_dir) if f.endswith(('.mp4'))]
video_files = sorted(video_files, key=lambda x: int(x.split("_")[-1].split(".")[0]))
video_files = video_files[:number_of_videos]
for i, video_file in enumerate(video_files):
video_path = os.path.join(video_dir, video_file)
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print(f"Failed to open video: {video_path}")
continue
frames = []
# get video total length ; our gt has 16 frame but we only use 15 frame so set video_length to 15 and start frame to 1
real_video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if real_video_length > video_length:
# set start frame to video_length - video_length
start_frame = real_video_length - video_length
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
else:
# set start frame to 0
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
while len(frames) < video_length:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = cv2.resize(frame, (size[1], size[0])) # Resize to (height, width)
frames.append(frame)
if len(frames) < video_length:
print(f"Video {video_file} has fewer frames than expected. Expected: {video_length}, Found: {len(frames)} Exiting...")
exit(1)
cap.release()
frames_np = np.array(frames[:video_length])
frames_np = np.transpose(frames_np, (0, 3, 1, 2))
videos_tensor[i] = torch.tensor(frames_np, dtype=torch.float32) / 255.0
return videos_tensor
# python scripts/tvideo/mc/common_metrics.py --video_dir1 metrics_table_1/oasis/oasis_official_results_no_demo_1gen15_mineworld_curation --video_dir2 metrics_table_1/frame_16_curation --video_length 15 --channel 3 --size "(224,384)" --output-file test_metrics.json
def main():
parser = argparse.ArgumentParser(description="Calculate FVD for two sets of videos.")
parser.add_argument("--video_dir1", type=str, required=True, help="Path to the first directory containing videos.")
parser.add_argument("--video_dir2", type=str, required=True, help="Path to the second directory containing videos.")
parser.add_argument("--video_length", type=int, default=32, help="Number of frames to retain from each video.")
parser.add_argument("--channel", type=int, default=3, help="Number of channels in the videos (default: 3 for RGB).")
parser.add_argument("--size", type=str, default="(224,384)", help="Size of the video frames (default: 256x256).")
parser.add_argument("--output-file", type=str)
args = parser.parse_args()
args.size = eval(args.size)
print("args.size", args.size)
number_of_videos = len([f for f in os.listdir(args.video_dir1) if f.endswith(".mp4")])
video_files = [f for f in os.listdir(args.video_dir1) if f.endswith(('.mp4'))]
number_of_videos = min(500,len(video_files))
print("number_of_videos", number_of_videos)
videos1 = load_videos_to_tensor(args.video_dir1, number_of_videos, args.video_length, args.channel, args.size, video_files)
videos2 = load_videos_to_tensor(args.video_dir2, number_of_videos, args.video_length, args.channel, args.size, video_files)
print("videos1.shape", videos1.shape, "videos2.shape", videos2.shape)
device = torch.device("cuda")
print(args.output_file)
result = {}
result['fvd'] = calculate_fvd(videos1, videos2, device, method='styleganv')
# result['fvd'] = calculate_fvd(videos1, videos2, device, method='videogpt')
result['ssim'] = calculate_ssim(videos1, videos2)
result['psnr'] = calculate_psnr(videos1, videos2)
result['lpips'] = calculate_lpips(videos1, videos2, device)
lpips_mean = np.mean(list(result['lpips']['value']))
ssim_mean = np.mean(list(result['ssim']['value']))
psnr_mean = np.mean(list(result['psnr']['value']))
fvd_mean = np.mean(list(result['fvd']['value']))
data_item = {"exp_name":args.video_dir1, "fvd":fvd_mean, "lpips":lpips_mean, "ssim":ssim_mean, "psnr":psnr_mean}
print(data_item)
os.makedirs(os.path.dirname(args.output_file), exist_ok=True)
result['mean'] = data_item
with open(args.output_file, "w") as f:
json.dump(result, f, indent=4)
print("results saved to ", args.output_file)
if __name__ == "__main__":
main()
================================================
FILE: metrics/tabulate_all_results.py
================================================
import argparse
import os
import json
import sys
import pandas as pd
import numpy as np
from rich import print
def tabluate_metrics(input_dir,output_path):
metrics_list = []
all_files = [f for f in os.listdir(input_dir) if f.endswith('.json')]
idm_results = [f for f in all_files if 'idm' in f]
fvd_results = [f for f in all_files if 'fvd' in f]
exps = set([i.replace("idm_","").replace(".json","") for i in idm_results]) & set([i.replace("fvd_","").replace(".json","") for i in fvd_results])
exps = list(exps)
print(f"[bold magenta][Tabulating Evaluation Results][/bold magenta]: Found experiments : {exps}")
for exp in exps:
idm_file = os.path.join(input_dir, f"idm_{exp}.json")
fvd_file = os.path.join(input_dir, f"fvd_{exp}.json")
with open(idm_file, 'r') as f:
idm_data = json.load(f)
with open(fvd_file, 'r') as f:
fvd_data = json.load(f)
fvd_data = fvd_data["mean"]
fvd_data.pop("exp_name", None)
idm_data = idm_data["metric_mean_on_task"]
metrics_entry = {
"experiment": exp,
}
# merge dict
metrics_entry.update(fvd_data)
metrics_entry.update(idm_data)
metrics_list.append(metrics_entry)
# Convert list of metrics to a DataFrame
df = pd.DataFrame(metrics_list)
# Save the DataFrame to a CSV file
df.to_csv(output_path, index=False)
print(f"[bold red][Tabulating Evaluation Results End][/bold red] Metrics tabulated and saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Tabulate metrics from JSON files")
parser.add_argument("--input_dir", type=str, required=True, help="Directory containing JSON metric files")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the tabulated metrics CSV file")
args = parser.parse_args()
tabluate_metrics(args.input_dir, args.output_path)
================================================
FILE: mineworld.py
================================================
import os
import sys
sys.path.append(os.getcwd())
import gradio as gr
from PIL import Image
import numpy as np
import torch
import cv2
from utils import load_model
from omegaconf import OmegaConf
from argparse import ArgumentParser
from collections import deque
import tempfile
import atexit
from torchvision import transforms
from einops import rearrange
from mcdataset import MCDataset
import itertools
class Buttons:
ATTACK = "attack"
BACK = "back"
FORWARD = "forward"
JUMP = "jump"
LEFT = "left"
RIGHT = "right"
SNEAK = "sneak"
SPRINT = "sprint"
USE = "use"
DROP = "drop"
SWAPHANDS = "swapHands"
PICKITEM = "pickItem"
ALL = [
ATTACK,
USE,
FORWARD,
BACK,
LEFT,
RIGHT,
JUMP,
SNEAK,
SPRINT,
DROP,
SWAPHANDS,
PICKITEM,
# INVENTORY,
# ESC,
] + [f"hotbar.{i}" for i in range(1, 10)]
KEYBOARD_BUTTON_MAPPING = {
"key.keyboard.s" :"back",
"key.keyboard.q" :"drop",
"key.keyboard.w" :"forward",
"key.keyboard.1" :"hotbar.1",
"key.keyboard.2" :"hotbar.2",
"key.keyboard.3" :"hotbar.3",
"key.keyboard.4" :"hotbar.4",
"key.keyboard.5" :"hotbar.5",
"key.keyboard.6" :"hotbar.6",
"key.keyboard.7" :"hotbar.7",
"key.keyboard.8" :"hotbar.8",
"key.keyboard.9" :"hotbar.9",
"key.keyboard.space" :"jump",
"key.keyboard.a" :"left",
"key.keyboard.d" :"right",
"key.keyboard.left.shift" :"sneak",
"key.keyboard.left.control" :"sprint",
"key.keyboard.f" :"swapHands",
}
# Template action
NOOP_ACTION = {
"forward": 0,
"back": 0,
"left": 0,
"right": 0,
"jump": 0,
"attack": 0,
"use": 0,
"pickItem": 0,
"drop": 0,
"sneak": 0,
"sprint": 0,
"swapHands": 0,
"hotbar.1": 0,
"hotbar.2": 0,
"hotbar.3": 0,
"hotbar.4": 0,
"hotbar.5": 0,
"hotbar.6": 0,
"hotbar.7": 0,
"hotbar.8": 0,
"hotbar.9": 0,
"camera": np.array([0, 0]),
}
ACTION_BUTTON = {
"forward": 0,
"back": 0,
"left": 0,
"right": 0,
"attack": 0,
"sprint": 0,
"jump": 0,
"use": 0,
"drop": 0,
"hotbar.1": 0,
"pickItem": 0,
}
FOR_BACK = {
"forward": 0,
"back": 0,
}
L_R = {
"left": 0,
"right": 0,
}
ATT_USE_DROP = {
"attack": 0,
"use": 0,
"drop": 0,
}
JUMP_SPR = {
"jump": 0,
"sprint": 0,
}
HORBAR = {
"hotbar.1": 0,
"hotbar.2": 0,
"hotbar.3": 0,
"hotbar.4": 0,
"hotbar.5": 0,
"hotbar.6": 0,
"hotbar.7": 0,
"hotbar.8": 0,
"hotbar.9": 0,
}
safe_globals = {"array": np.array}
AGENT_RESOLUTION = (384, 224)
CAMERA_SCALER = 360.0 / 2400.0
TOKEN_PER_IMAGE = 336
TOKEN_PER_ACTION = 11
VIDEO_FRAMES = []
GENERATED_FILES = []
frame_cache = []
action_cache = []
last_pos = 0
MC_ACTION_MAP = MCDataset()
SHOW_FRAMES = 8
REFERENCE_FRAME = None
CONTEXT_LEN = None
DIAGD = False
WINDOWSIZE = 4
def get_args():
parser = ArgumentParser()
parser.add_argument('--scene', type=str, default='./assets/scene.mp4')
parser.add_argument('--model_ckpt', type=str, default='./checkpoints/700M_16f.pt')
parser.add_argument('--config', type=str, default='./configs/700M_16f.yaml')
parser.add_argument('--reference_frame', type=int, default=8)
parser.add_argument('--diagd', action='store_true', help='use diagd')
parser.add_argument('--window_size', type=int, default=4)
args = parser.parse_args()
return args
def make_action_dict(action_line):
action_dict = {'ESC': 0, 'back': 0, 'drop': 0, 'forward': 0, 'hotbar.1': 0, 'hotbar.2': 0, 'hotbar.3': 0, 'hotbar.4': 0, 'hotbar.5': 0, 'hotbar.6': 0, 'hotbar.7': 0, 'hotbar.8': 0, 'hotbar.9': 0, 'inventory': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0, 'swapHands': 0, 'camera': np.array([0, 0]), 'attack': 0, 'use': 0, 'pickItem': 0}
if isinstance(action_line, str):
action_line = action_line.split(",")
action_dict['camera'] = np.array((int(action_line[-2]), int(action_line[-1])))
for act in action_line:
if act in Buttons.ALL:
action_dict[act] = 1
return action_dict
def stack_images(imgs):
width, height = imgs[0].size
new_im = Image.new('RGB', (4*width, height*2))
for i, im in enumerate(imgs):
new_im.paste(im, (width*(i%4), height*(i//4)))
return new_im
def get_action_line(acts):
action_lst = []
for k in acts.keys():
if k != "camera" and acts[k] == 1:
action_lst.append(k)
action_lst.append(str(acts["camera"][0]))
action_lst.append(str(acts["camera"][1]))
return ",".join(action_lst)
def run_prediction(btns_choices, cam_x_input, cam_y_input):
global frame_cache, action_cache, actions_show, images_show, VIDEO_FRAMES, last_pos, CONTEXT_LEN, REFERENCE_FRAME
assert len(frame_cache) == len(action_cache)+1
if len(action_cache) >= CONTEXT_LEN - 1:
for _ in range(CONTEXT_LEN - REFERENCE_FRAME):
frame_cache.popleft()
action_cache.popleft()
model.transformer.refresh_kvcache()
_frame_iter = itertools.islice(frame_cache, 0, len(frame_cache)-1)
_act_iter = itertools.islice(action_cache, 0, len(action_cache))
_vis_act = [
torch.cat([img, act], dim=1)
for img, act in zip(_frame_iter, _act_iter)
]
_vis_act.append(frame_cache[-1])
_vis_act = torch.cat(_vis_act, dim=-1)
_, last_pos = model.transformer.prefill_for_gradio(_vis_act)
act_dict = make_action_dict(btns_choices)
act_dict['camera'] = np.array((int(cam_y_input), int(cam_x_input)))
ongoing_act = MC_ACTION_MAP.get_action_index_from_actiondict(act_dict, action_vocab_offset=8192)
ongoing_act = torch.tensor(ongoing_act).unsqueeze(0).to("cuda")
action_cache.append(ongoing_act)
with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
if DIAGD:
next_frame, last_pos = model.transformer.diagd_img_token_for_gradio(input_action=ongoing_act, position_id = last_pos, max_new_tokens=TOKEN_PER_IMAGE, windowsize=4)
else:
next_frame, last_pos = model.transformer.decode_img_token_for_gradio(input_action=ongoing_act, position_id = last_pos, max_new_tokens=TOKEN_PER_IMAGE + 1) # +1 to fill kvcache
last_pos = last_pos[0]
next_frame = torch.cat(next_frame, dim=-1).to("cuda")
frame_cache.append(next_frame)
next_frame = tokenizer.token2image(next_frame)
next_frame = Image.fromarray(next_frame)
if len(images_show) >= SHOW_FRAMES:
images_show.popleft()
actions_show.popleft()
btns_choices = btns_choices + [np.array((int(cam_y_input), int(cam_x_input)))]
actions_show.append(','.join(str(x) for item in btns_choices for x in (item if isinstance(item, np.ndarray) else [item])))
images_show.append(next_frame)
VIDEO_FRAMES.append(next_frame)
return next_frame, stack_images(images_show), " ".join([str(x) for x in actions_show])
def run_prediction_n_times(n, btns_1, btns_2, btns_3, btns_4, btns_5, cam_x_input, cam_y_input):
btns_choices = btns_1 + btns_2 + btns_3 + btns_4 + btns_5
if cam_x_input is None:
cam_x_input = 0
if cam_y_input is None:
cam_y_input = 0
if n is None:
n = 1
for i in range(n):
yield run_prediction(btns_choices, cam_x_input, cam_y_input)
def step_pred_source_video_right(video_path, start):
global VIDEO_FRAMES, frame_cache, action_cache, REFERENCE_FRAME, CONTEXT_LEN, REFERENCE_FRAME
VIDEO_FRAMES.clear(); frame_cache.clear(); action_cache.clear()
if start is None or start < 0 or start > MAX_FRAME:
start = 0
return step_video(video_path, start, REFERENCE_FRAME)
def on_download_button_click(fps=6):
if not VIDEO_FRAMES:
print("The frames list is empty.")
return
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", dir="/tmp")
video_path = temp_file.name
temp_file.close()
video_writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"mp4v"),
fps,
AGENT_RESOLUTION
)
for frame in VIDEO_FRAMES:
frame_np = np.array(frame)
frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
video_writer.write(frame_bgr)
video_writer.release()
GENERATED_FILES.append(video_path)
os.chmod(video_path, 0o644)
return video_path
def cleanup_files():
for video_path in GENERATED_FILES:
try:
os.remove(video_path)
print(f"Deleted file: {video_path}")
except OSError as e:
print(f"Error deleting file {video_path}: {e}")
atexit.register(cleanup_files)
def step_video(video_path, start_frame, frame_count):
global images_show, actions_show, frame_cache, action_cache, VIDEO_FRAMES, last_pos, CONTEXT_LEN, REFERENCE_FRAME
VIDEO_FRAMES = []
images_show = []
actions_show = []
video = cv2.VideoCapture(video_path)
json_data = MC_ACTION_MAP.read_jsonl(video_path[:-4]+".jsonl")
frames_tensor = []
action_cache = []
for i in range(start_frame, start_frame + frame_count):
step_action = json_data[i]
step_action, _ = MC_ACTION_MAP.json_action_to_env_action(step_action)
actions_show.append(get_action_line(step_action))
act_index = MC_ACTION_MAP.get_action_index_from_actiondict(step_action, action_vocab_offset=8192)
act_index = torch.tensor(act_index).unsqueeze(0)
action_cache.append(act_index.to("cuda"))
video.set(cv2.CAP_PROP_POS_FRAMES, i)
ret, frame = video.read()
try:
if not ret:
raise ValueError(f"frame {i} not ret")
cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)
frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)
frame = cv2.resize(frame, AGENT_RESOLUTION, interpolation=cv2.INTER_LINEAR)
images_show.append(Image.fromarray(frame))
VIDEO_FRAMES.append(Image.fromarray(frame))
frames_tensor.append(torch.from_numpy(frame))
except Exception as e:
print(f"Could not read frame from video {video_path}: {e}")
video.release()
frames_tensor = torch.stack(frames_tensor, dim=0).to("cuda")
frames_tensor = frames_tensor.permute(0, 3, 1, 2).float() / 255.0
frames_tensor = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(frames_tensor)
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
images_token = tokenizer.tokenize_images(frames_tensor)
images_token = rearrange(images_token, '(b t) h w -> (b t) (h w)', b=1)
frame_cache = deque(torch.split(images_token, split_size_or_sections=1, dim=0))
action_cache = deque(action_cache)
action_cache.pop()
images_show = deque(images_show)
actions_show = deque(actions_show)
actions_show.pop()
model.transformer.refresh_kvcache()
_frame_iter = itertools.islice(frame_cache, 0, len(frame_cache)-1)
_act_iter = itertools.islice(action_cache, 0, len(action_cache))
_vis_act = [
torch.cat([img, act], dim=1)
for img, act in zip(_frame_iter, _act_iter)
]
_vis_act.append(frame_cache[-1])
_vis_act = torch.cat(_vis_act, dim=-1)
_, last_pos = model.transformer.prefill_for_gradio(_vis_act)
while len(images_show) > SHOW_FRAMES:
images_show.popleft()
actions_show.popleft()
# WARNING: why dont pop actions
return stack_images(images_show), " ".join([str(x) for x in actions_show]), None
css = """
.custom-tab h2 {
font-size: 34px; /* 字体大小 */
font-weight: bold; /* 加粗字体 */
color: #ff6600; /* 字体颜色 */
text-shadow: 1px 1px 2px #000000; /* 文字阴影效果 */
}
"""
if __name__ == "__main__":
args = get_args()
if args.diagd:
DIAGD = True
WINDOWSIZE = args.window_size
cap = cv2.VideoCapture(args.scene)
global MAX_FRAME
MAX_FRAME = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 10
config = OmegaConf.load(args.config)
REFERENCE_FRAME = args.reference_frame
CONTEXT_LEN = int(config.model.params.transformer_config.params.max_position_embeddings / (TOKEN_PER_ACTION + TOKEN_PER_IMAGE))
assert CONTEXT_LEN > REFERENCE_FRAME
model = load_model(config, args.model_ckpt, gpu=True, eval_mode=True)
tokenizer = model.tokenizer
with gr.Blocks(css=css) as demo:
with gr.Tab("MineWorld", elem_classes="custom-tab"):
source_video_path = gr.Text(value=args.scene, visible=False)
with gr.Row():
source_video_actions = gr.Markdown(visible=False)
instruction = gr.Markdown("press 'Jump to start frame' to init or restart the game, you can choose different sences by modifed start frame from 0 to 4100", visible=False)
with gr.Row():
source_video_images = gr.Image(width=1280, height=360, label="last 8 frames", show_fullscreen_button = True, every=1)
with gr.Row(equal_height=True):
with gr.Column(min_width=60):
vid_frame_start = gr.Number(step=1, value=0, info="start frame", min_width=20, show_label=False, minimum=0, maximum=MAX_FRAME)
# vid_num_frames = gr.Number(step=1, value=4, label="num_frames", min_width=50)
# with gr.Column(min_width=60):
run_steps = gr.Number(step=1, value=1, info="Repeat same action n times", min_width=20, minimum=1, maximum=8, show_label=False)
with gr.Column(min_width=60):
btn1 = list(FOR_BACK.keys())
btns_1 = gr.CheckboxGroup(choices=btn1, show_label=False)
vid_right_btn = gr.Button(value="Jump to start frame", size='sm')
with gr.Column(min_width=60):
btn2 = list(L_R.keys())
btns_2 = gr.CheckboxGroup(choices=btn2, show_label=False)
predict_run_btn = gr.Button(value="Run", variant="primary", size='sm')
with gr.Column(min_width=60):
btn4 = list(JUMP_SPR.keys())
btns_4 = gr.CheckboxGroup(choices=btn4, show_label=False)
download_game_btn = gr.Button("Generate Video", size='sm')
with gr.Column(min_width=60):
cam_y_input = gr.Number(step=1, value=0, info="camera Y ⬆️(-),0,(+)⬇️", min_width=20, minimum=-90, maximum=90, show_label=False)
cam_x_input = gr.Number(step=1, value=0, info="camera X ⬅️(-),0,(+)➡️", min_width=20, minimum=-90, maximum=90, show_label=False)
with gr.Row():
with gr.Column(min_width=250):
video_display = gr.Video(label="video", width=384, height=224)
with gr.Column(min_width=200):
predict_result_imgs = gr.Image(label="last generated frame",width=384, height=224)
with gr.Column(min_width=200):
with gr.Row():
btn3 = list(ATT_USE_DROP.keys())
btns_3 = gr.CheckboxGroup(choices=btn3, show_label=False)
with gr.Row():
btn5 = list(HORBAR.keys())
btns_5 = gr.CheckboxGroup(choices=btn5, show_label=False)
vid_right_btn.click(fn=step_pred_source_video_right, inputs=[source_video_path, vid_frame_start],
outputs=[source_video_images, source_video_actions, predict_result_imgs])
predict_run_btn.click(fn=run_prediction_n_times, inputs=[run_steps, btns_1, btns_2, btns_3, btns_4, btns_5, cam_x_input, cam_y_input],
outputs=[predict_result_imgs, source_video_images, source_video_actions],)
download_game_btn.click(fn=on_download_button_click, inputs=[], outputs=video_display)
demo.load(fn=step_pred_source_video_right, inputs=[source_video_path, gr.Number(value=25, visible=False)],
outputs=[source_video_images, source_video_actions, predict_result_imgs])
demo.queue()
demo.launch(server_name="0.0.0.0", max_threads=256, server_port=7861, share=True)
================================================
FILE: requirements.txt
================================================
torch==2.6.0
torchvision==0.21.0
omegaconf==2.3.0
transformers==4.48.1
opencv-python==4.11.0.86
attrs==25.3.0
diffusers==0.32.2
gradio==5.24.0
einops==0.8.1
diffusers==0.32.2
scipy==1.15.2
torch-fidelity==0.3.0
gym3==0.3.3
gym==0.26.2
scikit-learn==1.6.1
================================================
FILE: scripts/compute_metrics.sh
================================================
#!/bin/bash
VIDEO_RESULTS_ROOT_DEFAULT="videos"
METRICS_ROOT_DEFAULT="metrics_log"
JSONL_PATH_DEFAULT="validation/validation"
IDM_CKPT_DIR="checkpoints/IDM"
VIDEO_RESULTS_ROOT=${1:-$VIDEO_RESULTS_ROOT_DEFAULT}
METRICS_ROOT=${2:-$METRICS_ROOT_DEFAULT}
JSONL_PATH=${3:-$JSONL_PATH_DEFAULT}
echo "VIDEO_RESULTS_ROOT = $VIDEO_RESULTS_ROOT"
echo "METRICS_ROOT = $METRICS_ROOT"
echo "JSONL_PATH = $JSONL_PATH"
# Loop through each subdirectory in VIDEO_RESULTS_ROOT
for video_dir1 in "$VIDEO_RESULTS_ROOT"/*/; do
# Skip the 'metrics' directory
if [ -d "$video_dir1" ] && [ "$(basename "$video_dir1")" != "metrics" ]; then
# Construct the output file name based on the video directory name
fvd_output_file="$METRICS_ROOT/fvd_$(basename "$video_dir1").json"
echo $fvd_output_file
# Run the python command for each video directory
python metrics/common_metrics.py --video_dir2 $JSONL_PATH --video_length 15 --channel 3 --size "(224,384)" \
--video_dir1 "$video_dir1" --output-file "$fvd_output_file"
idm_output_file="$METRICS_ROOT/idm_$(basename "$video_dir1").json"
python metrics/IDM/inverse_dynamics_model.py --weights $IDM_CKPT_DIR/"4x_idm.weights" \
--infer-demo-num 1 --n-frames 15 \
--model $IDM_CKPT_DIR/"4x_idm.model" --video-path $video_dir1 \
--output-file "$idm_output_file" \
--jsonl-path $JSONL_PATH
fi
done
python metrics/tabulate_all_results.py --input_dir $METRICS_ROOT --output_path $METRICS_ROOT/latest_metrics.csv
================================================
FILE: scripts/inference_16f_models.sh
================================================
DATA_ROOT="validation/validation"
###########################
#### Inference 1200M models
###########################
CONFIG="configs/1200M_16f.yaml"
CKPT_PATH="checkpoints/1200M_16f.ckpt"
OUTPUT_PATH="./videos/1200M_16f200_demo1gen15_naive"
python inference.py \
--data_root $DATA_ROOT \
--config $CONFIG \
--model_ckpt $CKPT_PATH \
--demo_num 1 --frames 15 \
--accelerate-algo 'naive' \
--top_p 0.8 \
--output_dir $OUTPUT_PATH
================================================
FILE: scripts/setup_metrics.sh
================================================
# clone common_metrics_on_video_quality repository
git clone git@github.com:CIntellifusion/common_metrics_on_video_quality.git
# get IDM weights
mkdir -p checkpoints/IDM
wget https://openaipublic.blob.core.windows.net/minecraft-rl/idm/4x_idm.model -O checkpoints/IDM/4x_idm.model
wget https://openaipublic.blob.core.windows.net/minecraft-rl/idm/4x_idm.weights -O checkpoints/IDM/4x_idm.weights
================================================
FILE: utils.py
================================================
import torch
import importlib
from rich import print
from typing import Union
import numpy as np
import os
def print0(*args, **kwargs):
print(*args, **kwargs) # python -m rich.color
def tensor_to_uint8(tensor):
tensor = torch.clamp(tensor, -1.0, 1.0)
tensor = (tensor + 1.0) / 2.0
tensor = (tensor.cpu().numpy() * 255).astype(np.uint8)
return tensor
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
model = instantiate_from_config(config)
if sd is not None:
missing, unexpected = model.load_state_dict(sd, strict=False)
if len(missing) != 0:
raise ValueError(f"Missing keys: {missing}")
if gpu:
model.cuda()
if eval_mode:
model.eval()
return {"model": model}
def get_valid_dirs(dir1: str, dir2: Union[None, str] = None, dir3: Union[None, str] = None) -> Union[None, str]:
if (dir1 is not None) and os.path.isdir(dir1): return dir1
elif (dir2 is not None) and os.path.isdir(dir2): return dir2
elif (dir3 is not None) and os.path.isdir(dir3): return dir3
else: return None
def get_valid_paths(path1: str, path2: Union[None, str] = None, path3: Union[None, str] = None) -> Union[None, str]:
if (path1 is not None) and os.path.isfile(path1): return path1
elif (path2 is not None) and os.path.isfile(path2): return path2
elif (path3 is not None) and os.path.isfile(path3): return path3
else: return None
def load_model(config, ckpt, gpu, eval_mode):
if str(ckpt).endswith(".bin"):
weight = torch.load(ckpt)
elif ckpt:
weight = torch.load(ckpt, map_location="cpu")["state_dict"]
model = load_model_from_config(config.model, weight, gpu=gpu, eval_mode=eval_mode)["model"]
model.load_state_dict(weight, strict=False)
model.to(torch.float16)
return model
================================================
FILE: vae.py
================================================
import torch
import torch.nn as nn
import diffusers
from safetensors.torch import load_file as load_safetensors
from utils import print0, get_valid_paths, tensor_to_uint8
class VAE(nn.Module):
def __init__(self,
config_path: str,
ckpt_path: str,
):
super().__init__()
config_path = get_valid_paths(config_path)
print0(f"[bold magenta]\[VAE][/bold magenta] Loading VQGAN from {config_path}")
self.model = diffusers.VQModel.from_config(config_path)
ckpt_path = get_valid_paths(ckpt_path)
print0(f"[bold magenta]\[VAE][/bold magenta] Use ckpt_path: {ckpt_path}")
self.init_from_ckpt(ckpt_path)
def init_from_ckpt(
self, path: str
) -> None:
if path.endswith("ckpt"):
ckpt = torch.load(path, map_location="cpu", weights_only=False)
if "state_dict" in ckpt:
weights = ckpt["state_dict"]
else:
weights = ckpt
elif path.endswith("safetensors"):
weights = load_safetensors(path)
else:
raise NotImplementedError
missing, unexpected = self.load_state_dict(weights, strict=False)
print0(
f"[bold magenta]\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print0(f"[bold magenta]\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Missing Keys: {missing}")
# if len(unexpected) > 0:
# print0(f"[bold magenta]\[tvae.models.amused_vqvae][AutoencodingLegacy][/bold magenta] Unexpected Keys: {unexpected}")
@torch.no_grad()
def tokenize_images(self, x: torch.Tensor, sane_index_shape: bool = True):
h = self.model.encoder(x)
h = self.model.quant_conv(h)
if sane_index_shape:
orig_sane_index_shape = self.model.quantize.sane_index_shape
self.model.quantize.sane_index_shape = True
z_q, loss, (perplexity, min_encodings, min_encoding_indices) = self.model.quantize(h)
if sane_index_shape:
self.model.quantize.sane_index_shape = orig_sane_index_shape
return min_encoding_indices
# yang ye
@torch.no_grad()
def token2image(self, tokens):
assert tokens.max() < 8192, f"code max value is {tokens.max()}"
shape = (1, 14, 24, 64)
with torch.autocast(device_type='cuda', dtype=torch.float32):
quant = self.model.quantize.get_codebook_entry(tokens, shape)
quant2 = self.model.post_quant_conv(quant)
dec = self.model.decoder(quant2)
img = tensor_to_uint8(dec[0]).transpose(1, 2, 0)
return img